model_runner.py 43.5 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
2
3
4
5
6
7
8
9
10
11
12
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
import time
from copy import deepcopy

import numpy as np
import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
13
from vllm.distributed.parallel_state import (
14
    get_dcp_group,
15
16
17
    get_pp_group,
    prepare_communication_buffer_for_model,
)
Woosuk Kwon's avatar
Woosuk Kwon committed
18
19
20
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model_loader
21
from vllm.multimodal import MULTIMODAL_REGISTRY
22
from vllm.sequence import IntermediateTensors
23
from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
Woosuk Kwon's avatar
Woosuk Kwon committed
24
25
26
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
27
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
28
from vllm.v1.worker.cp_utils import check_attention_cp_compatibility
29
from vllm.v1.worker.gpu.async_utils import AsyncOutput
Woosuk Kwon's avatar
Woosuk Kwon committed
30
31
from vllm.v1.worker.gpu.attn_utils import (
    build_attn_metadata,
32
    build_slot_mappings_by_layer,
Woosuk Kwon's avatar
Woosuk Kwon committed
33
34
35
    get_kv_cache_spec,
    init_attn_backend,
    init_kv_cache,
36
    prepare_dcp_local_seq_lens,
Woosuk Kwon's avatar
Woosuk Kwon committed
37
38
)
from vllm.v1.worker.gpu.block_table import BlockTables
39
from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu
Woosuk Kwon's avatar
Woosuk Kwon committed
40
from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
41
from vllm.v1.worker.gpu.dp_utils import (
42
    get_cudagraph_and_dp_padding,
43
44
    make_num_tokens_across_dp,
)
Woosuk Kwon's avatar
Woosuk Kwon committed
45
46
47
from vllm.v1.worker.gpu.input_batch import (
    InputBatch,
    InputBuffers,
48
    combine_sampled_and_draft_tokens,
49
    expand_idx_mapping,
50
    get_num_sampled_and_rejected,
51
    post_update,
52
53
    prepare_pos_seq_lens,
    prepare_prefill_inputs,
Woosuk Kwon's avatar
Woosuk Kwon committed
54
)
55
56
57
58
59
from vllm.v1.worker.gpu.kv_connector import (
    NO_OP_KV_CONNECTOR,
    KVConnector,
    get_kv_connector,
)
60
from vllm.v1.worker.gpu.lora_utils import LoraState
61
from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
62
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
63
from vllm.v1.worker.gpu.pp_utils import pp_broadcast, pp_receive
64
from vllm.v1.worker.gpu.sample.output import SamplerOutput
65
from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker
66
from vllm.v1.worker.gpu.sample.sampler import Sampler
67
from vllm.v1.worker.gpu.spec_decode import init_speculator
68
from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample
69
from vllm.v1.worker.gpu.spec_decode.utils import DraftTokensHandler
70
from vllm.v1.worker.gpu.states import RequestState
71
from vllm.v1.worker.gpu.structured_outputs import StructuredOutputsWorker
Woosuk Kwon's avatar
Woosuk Kwon committed
72
73
74
75
76
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

logger = init_logger(__name__)


77
class GPUModelRunner(LoRAModelRunnerMixin):
Woosuk Kwon's avatar
Woosuk Kwon committed
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Set up per-worker state for the GPU model runner.

        Caches the sub-configs of ``vllm_config``, derives dtypes and batch
        limits, and constructs the helper objects (request state, input
        buffers, sampler, CUDA graph manager, structured-outputs worker,
        LoRA state, draft-token handler) used by the rest of the runner.
        The model itself is loaded later by ``load_model`` and the KV cache
        is allocated later by ``initialize_kv_cache``.

        Args:
            vllm_config: Top-level vLLM configuration.
            device: Device on which buffers and CUDA streams are created.
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.compilation_config = vllm_config.compilation_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.speculative_config = vllm_config.speculative_config
        self.observability_config = vllm_config.observability_config

        self.device = device
        self.dtype = self.model_config.dtype
        # KV cache uses the model dtype unless an explicit cache dtype is set.
        self.kv_cache_dtype = self.dtype
        if self.cache_config.cache_dtype != "auto":
            # Quantized KV cache.
            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                self.cache_config.cache_dtype
            ]
        self.is_pooling_model = False

        self.vocab_size = self.model_config.get_vocab_size()
        self.max_model_len = self.model_config.max_model_len
        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
        self.max_num_reqs = self.scheduler_config.max_num_seqs
        self.inputs_embeds_size = self.model_config.get_inputs_embeds_size()

        # Multimodal
        self.mm_registry = MULTIMODAL_REGISTRY
        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
            self.model_config
        )
        if self.supports_mm_inputs:
            self.encoder_runner = EncoderRunner(
                max_num_tokens=self.max_num_tokens,
                hidden_size=self.inputs_embeds_size,
                dtype=self.dtype,
                device=self.device,
            )
        self.uses_mrope = self.model_config.uses_mrope
        if self.uses_mrope:
            self.mrope_states = MRopeState(
                max_num_reqs=self.max_num_reqs,
                max_num_tokens=self.max_num_tokens,
                max_model_len=self.max_model_len,
                device=self.device,
            )

        self.use_async_scheduling = self.scheduler_config.async_scheduling
        self.output_copy_stream = torch.cuda.Stream(self.device)
        self.output_copy_event = torch.cuda.Event()

        # Speculative decoding is enabled iff a speculative config is given.
        if self.speculative_config is not None:
            self.do_spec_decode = True
            self.num_speculative_steps = self.speculative_config.num_speculative_tokens
            self.speculator = init_speculator(self.vllm_config, self.device)
        else:
            self.do_spec_decode = False
            self.num_speculative_steps = 0
            self.speculator = None

        self.req_states = RequestState(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_batched_tokens=self.max_num_tokens,
            num_speculative_steps=self.num_speculative_steps,
            vocab_size=self.vocab_size,
            device=self.device,
        )
        self.input_buffers = InputBuffers(
            max_num_reqs=self.max_num_reqs,
            max_num_tokens=self.max_num_tokens,
            device=self.device,
        )
        self.sampler = Sampler(
            max_num_reqs=self.max_num_reqs,
            vocab_size=self.vocab_size,
            device=self.device,
            req_states=self.req_states,
            logprobs_mode=self.model_config.logprobs_mode,
            num_speculative_tokens=self.num_speculative_steps + 1,
        )
        self.prompt_logprobs_worker = PromptLogprobsWorker(self.max_num_reqs)

        # CUDA graphs.
        self.cudagraph_manager = CudaGraphManager(
            self.vllm_config, self.uses_mrope, self.device
        )
        # Structured outputs worker. Sized for one logit per request plus one
        # per speculative step.
        self.structured_outputs_worker = StructuredOutputsWorker(
            max_num_logits=self.max_num_reqs * (self.num_speculative_steps + 1),
            vocab_size=self.vocab_size,
            device=self.device,
        )
        # LoRA-related workers.
        self.lora_state = LoraState(max_num_reqs=self.max_num_reqs)

        # Draft tokens propagation - for spec-dec + struct outputs.
        self.draft_tokens_handler = DraftTokensHandler(self.device)

        # KV Connector if configured. Starts as the no-op connector and is
        # replaced in initialize_kv_cache.
        self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR

        # Pipeline parallelism.
        self.use_pp = self.parallel_config.pipeline_parallel_size > 1
        if self.use_pp:
            self.is_first_pp_rank = get_pp_group().is_first_rank
            self.is_last_pp_rank = get_pp_group().is_last_rank
        else:
            self.is_first_pp_rank = True
            self.is_last_pp_rank = True

    def update_max_model_len(self, max_model_len: int) -> None:
        """Update the maximum model length on the runner and its request state.

        Args:
            max_model_len: New maximum sequence length.
        """
        # NOTE(review): BlockTables (built in initialize_kv_cache) also takes
        # max_model_len at construction and is not updated here — confirm
        # this method is only called before the KV cache is initialized.
        self.max_model_len = max_model_len
        self.req_states.max_model_len = max_model_len

    @staticmethod
    def get_supported_tasks() -> tuple[str]:
        """Return the tasks supported by this runner (text generation only)."""
        return ("generate",)

    def load_model(self, *args, **kwargs) -> None:
        """Load the model (plus LoRA and speculator weights, if enabled).

        The memory consumed while loading is stored in
        ``self.model_memory_usage`` and logged together with the elapsed
        time. Extra positional/keyword arguments are accepted but ignored.
        """
        time_before_load = time.perf_counter()
        with DeviceMemoryProfiler() as m:
            model_loader = get_model_loader(self.vllm_config.load_config)
            logger.info("Loading model from scratch...")

            self.model = model_loader.load_model(
                vllm_config=self.vllm_config,
                model_config=self.vllm_config.model_config,
            )
            if self.lora_config:
                self.model = self.load_lora_model(
                    self.model, self.vllm_config, self.device
                )
            if self.do_spec_decode:
                self.speculator.load_model(self.model)
        time_after_load = time.perf_counter()

        self.model_memory_usage = m.consumed_memory
        logger.info(
            "Model loading took %s GiB and %.6f seconds",
            format_gib(m.consumed_memory),
            time_after_load - time_before_load,
        )

        # Prepare communication buffers for the main model and, when the
        # speculator exposes its own `model`, for that model as well.
        prepare_communication_buffer_for_model(self.model)
        if self.do_spec_decode:
            speculator_model = getattr(self.speculator, "model", None)
            if speculator_model is not None:
                prepare_communication_buffer_for_model(speculator_model)

    def get_model(self) -> nn.Module:
        """Return the model instance assigned by ``load_model``."""
        loaded_model: nn.Module = self.model
        return loaded_model

    def get_kv_cache_spec(self):
        """Return the KV cache spec for this runner's vLLM config.

        Thin wrapper over the module-level ``get_kv_cache_spec`` helper.
        """
        spec = get_kv_cache_spec(self.vllm_config)
        return spec

    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate the KV cache and build block tables and attention state.

        Args:
            kv_cache_config: KV cache layout. The runner stores and uses a
                deep copy, so the caller's object is not shared with it.
        """
        kv_cache_config = deepcopy(kv_cache_config)
        self.kv_cache_config = kv_cache_config
        # One block size per KV cache group.
        block_sizes = [
            kv_cache_group.kv_cache_spec.block_size
            for kv_cache_group in kv_cache_config.kv_cache_groups
        ]

        self.block_tables = BlockTables(
            block_sizes=block_sizes,
            max_num_reqs=self.max_num_reqs,
            max_num_batched_tokens=self.max_num_tokens,
            max_model_len=self.max_model_len,
            device=self.device,
            cp_kv_cache_interleave_size=(
                self.parallel_config.cp_kv_cache_interleave_size
            ),
        )

        self.attn_backends, self.attn_metadata_builders = init_attn_backend(
            self.kv_cache_config, self.vllm_config, self.device
        )
        check_attention_cp_compatibility(self.vllm_config)
        if self.do_spec_decode:
            # HACK(woosuk)
            self.speculator.set_attn(
                self.kv_cache_config,
                self.attn_metadata_builders,
                self.block_tables,
            )

        # init_kv_cache fills self.kv_caches and returns a per-layer mapping
        # that is handed to the KV connector.
        self.kv_caches: list[torch.Tensor] = []
        kv_caches_dict = init_kv_cache(
            self.kv_caches,
            self.compilation_config.static_forward_context,
            self.kv_cache_config,
            self.attn_backends,
            self.device,
        )
        self.kv_connector = get_kv_connector(self.vllm_config, kv_caches_dict)

        # Attention groups are not supported.
        self.attn_groups = []  # type: ignore

    def prepare_dummy_attn_metadata(self, input_batch: InputBatch) -> None:
        """Attach dummy attention metadata and slot mappings to ``input_batch``.

        Uses the dummy block tables and slot mappings provided by
        ``self.block_tables`` (sized by ``input_batch.num_reqs`` and
        ``input_batch.num_tokens``). Sets ``input_batch.attn_metadata`` and
        ``input_batch.slot_mappings`` in place.
        """
        block_tables = self.block_tables.get_dummy_block_tables(input_batch.num_reqs)
        slot_mappings = self.block_tables.get_dummy_slot_mappings(
            input_batch.num_tokens
        )
        slot_mappings_by_layer = build_slot_mappings_by_layer(
            slot_mappings, self.kv_cache_config
        )
        attn_metadata = build_attn_metadata(
            attn_metadata_builders=self.attn_metadata_builders,
            num_reqs=input_batch.num_reqs,
            num_tokens=input_batch.num_tokens,
            query_start_loc_gpu=input_batch.query_start_loc,
            query_start_loc_cpu=torch.from_numpy(input_batch.query_start_loc_np),
            max_query_len=input_batch.num_scheduled_tokens.max().item(),
            seq_lens=input_batch.seq_lens,
            max_seq_len=self.max_model_len,
            block_tables=block_tables,
            slot_mappings=slot_mappings,
            kv_cache_config=self.kv_cache_config,
            dcp_local_seq_lens=self.input_buffers.dcp_local_seq_lens,
        )
        input_batch.attn_metadata = attn_metadata
        input_batch.slot_mappings = slot_mappings_by_layer

    @torch.inference_mode()
    def _dummy_run(
        self, num_tokens: int, *args, skip_attn: bool = True, **kwargs
    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
        """Run the model once on a synthetic batch of ``num_tokens`` tokens.

        The tokens are spread over ``min(num_tokens, max_num_reqs)`` dummy
        requests, with the last request absorbing the remainder. The KV
        connector is disabled for the duration of the run and is re-enabled
        afterwards, even if the run raises. Extra positional/keyword
        arguments are accepted but ignored.

        Args:
            num_tokens: Total number of tokens in the dummy batch.
            skip_attn: Forwarded to ``execute_model`` as
                ``skip_attn_for_dummy_run``.

        Returns:
            ``(hidden_states, sample_hidden_states)`` on the last PP rank,
            ``(None, None)`` on every other rank.

        Raises:
            ValueError: If ``num_tokens`` is not positive.
        """
        if num_tokens <= 0:
            # Without this guard, 0 fails with ZeroDivisionError below and
            # negative values fail with an opaque IndexError.
            raise ValueError(f"num_tokens must be positive, got {num_tokens}")

        # Create a dummy scheduler output.
        num_reqs = min(num_tokens, self.max_num_reqs)
        num_tokens_per_request = [num_tokens // num_reqs] * num_reqs
        num_tokens_per_request[-1] += num_tokens % num_reqs
        assert sum(num_tokens_per_request) == num_tokens
        num_scheduled_tokens = {
            f"_dummy_req_{i}": n for i, n in enumerate(num_tokens_per_request)
        }
        dummy_scheduler_output = SchedulerOutput.make_empty()
        dummy_scheduler_output.total_num_scheduled_tokens = num_tokens
        dummy_scheduler_output.num_scheduled_tokens = num_scheduled_tokens

        # Disable any use of KVConnector for dummy runs.
        self.kv_connector.set_disabled(True)
        try:
            # For non-first PP ranks, create dummy intermediate_tensors.
            intermediate_tensors = None
            if not self.is_first_pp_rank:
                intermediate_tensors = self.model.make_empty_intermediate_tensors(
                    batch_size=num_tokens,
                    dtype=self.model_config.dtype,
                    device=self.device,
                )

            # Execute the model.
            self.execute_model(
                dummy_scheduler_output,
                intermediate_tensors=intermediate_tensors,
                dummy_run=True,
                skip_attn_for_dummy_run=skip_attn,
            )
        finally:
            # Always re-enable the connector so a failed dummy run
            # (profiling/warmup) does not leave it disabled.
            self.kv_connector.set_disabled(False)

        # Non-last PP ranks don't produce output for sampling.
        if not self.is_last_pp_rank:
            return None, None

        assert self.execute_model_state is not None
        hidden_states, input_batch, _ = self.execute_model_state
        assert hidden_states is not None  # Last PP rank always has hidden_states
        sample_hidden_states = hidden_states[input_batch.logits_indices]
        return hidden_states, sample_hidden_states

    @torch.inference_mode()
    def _dummy_sampler_run(self, hidden_states: torch.Tensor) -> None:
        """Run the sampler once on logits computed from ``hidden_states``.

        Each row of ``hidden_states`` is treated as one dummy request with
        position 0, input token id 0 and local position 0. ``profile_run``
        calls this method.

        Args:
            hidden_states: Hidden states to compute logits from;
                ``shape[0]`` is used as the number of requests.
        """
        num_reqs = hidden_states.shape[0]
        logits = self.model.compute_logits(hidden_states)
        idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=self.device)
        idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
        pos = torch.zeros(num_reqs, dtype=torch.int64, device=self.device)
        dummy_input_ids = torch.zeros(num_reqs, dtype=torch.int32, device=self.device)
        expanded_local_pos = torch.zeros(
            num_reqs, dtype=torch.int32, device=self.device
        )
        # NOTE(woosuk): During the initial memory profiling, the sampler may skip
        # top_k, top_p, and logprobs, using less GPU memory than what is possible
        # during actual execution.
        # NOTE(review): idx_mapping_np is passed for two consecutive positional
        # parameters — presumably intentional for a one-logit-per-request batch;
        # confirm against Sampler.__call__.
        self.sampler(
            logits,
            idx_mapping,
            idx_mapping_np,
            idx_mapping_np,
            pos,
            dummy_input_ids,
            expanded_local_pos,
        )

    @torch.inference_mode()
    def profile_run(self) -> None:
        """Run a max-size dummy forward pass for profiling.

        The model runs once with ``max_num_tokens`` tokens and attention
        skipped. On the last PP rank, the sampler and (if enabled) the
        speculator are then run as well. Finally it synchronizes the device,
        drops the outputs and runs the garbage collector.
        """
        hidden_states, sample_hidden_states = self._dummy_run(
            self.max_num_tokens, skip_attn=True
        )

        # Only run sampler on last PP rank (non-last ranks return None).
        if self.is_last_pp_rank:
            assert sample_hidden_states is not None
            self._dummy_sampler_run(sample_hidden_states)

            if self.do_spec_decode:
                num_tokens_across_dp = make_num_tokens_across_dp(
                    self.parallel_config.data_parallel_size, self.max_num_tokens
                )
                self.speculator.run_model(
                    self.max_num_tokens,
                    attn_metadata=None,
                    slot_mappings=None,
                    num_tokens_across_dp=num_tokens_across_dp,
                )

        torch.cuda.synchronize()
        del hidden_states, sample_hidden_states
        gc.collect()

    def reset_mm_cache(self) -> None:
        """Reset the encoder runner's multimodal cache (no-op for text-only)."""
        if self.supports_mm_inputs:
            self.encoder_runner.reset_mm_cache()

    def reset_encoder_cache(self) -> None:
        """Reset the encoder runner's encoder cache (no-op for text-only)."""
        if self.supports_mm_inputs:
            self.encoder_runner.reset_encoder_cache()

    def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int:
        """Return the number of input tokens for a step.

        Sequence parallelism is not supported yet, so no padding is applied
        and the scheduled token count is returned as is.
        """
        num_input_tokens = num_scheduled_tokens
        return num_input_tokens

    @torch.inference_mode()
    def capture_model(self) -> int:
        """Capture CUDA graphs for the model (and the speculator, if enabled).

        Capture is skipped, returning 0, when the CUDA graph manager reports
        nothing to capture or when pipeline parallelism is enabled.

        Returns:
            The drop in free GPU memory across capture, in bytes (as reported
            by ``torch.cuda.mem_get_info``); 0 if capture was skipped.
        """
        if not self.cudagraph_manager.needs_capture():
            logger.warning(
                "Skipping CUDA graph capture. To turn on CUDA graph capture, "
                "ensure `cudagraph_mode` was not manually set to `NONE`"
            )
            return 0

        # TODO (zhanqiu): support CUDA graph for PP.
        if self.use_pp:
            logger.warning_once(
                "Skipping CUDA graph capture because pipeline parallel is "
                "enabled. Pipeline parallel is currently eager-only.",
            )
            return 0

        start_time = time.perf_counter()
        # Release garbage and cached allocator blocks before taking the
        # free-memory baseline.
        gc.collect()
        torch.cuda.empty_cache()
        start_free_gpu_memory = torch.cuda.mem_get_info()[0]

        with self.maybe_setup_dummy_loras(self.lora_config):
            mrope_positions = None
            if self.uses_mrope:
                mrope_positions = self.mrope_states.mrope_positions
            inputs_embeds = None
            if self.supports_mm_inputs:
                inputs_embeds = self.encoder_runner.inputs_embeds
            self.cudagraph_manager.capture(
                model=self.model,
                input_buffers=self.input_buffers,
                mrope_positions=mrope_positions,
                inputs_embeds=inputs_embeds,
                block_tables=self.block_tables,
                attn_metadata_builders=self.attn_metadata_builders,
                kv_cache_config=self.kv_cache_config,
            )
            if self.do_spec_decode:
                self.speculator.capture_model()

        end_time = time.perf_counter()
        end_free_gpu_memory = torch.cuda.mem_get_info()[0]
        elapsed_time = end_time - start_time
        cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
        # This usually takes 5~20 seconds.
        logger.info(
            "Graph capturing finished in %.0f secs, took %.2f GiB",
            elapsed_time,
            cuda_graph_size / (1 << 30),
        )
        return cuda_graph_size

    def warmup_for_prefill(self) -> None:
        """Run a max-size dummy prefill (with attention) for FlashInfer backends.

        Only runs when every attention backend's name contains "FLASHINFER".
        """
        # For FlashInfer, we would like to execute a dummy prefill run
        # to trigger JIT compilation.
        # NOTE(review): all() is True for an empty attn_backends mapping, so
        # the warmup would also run then — presumably attn_backends is always
        # non-empty after initialize_kv_cache; confirm.
        if all("FLASHINFER" in b.get_name() for b in self.attn_backends.values()):
            self._dummy_run(self.max_num_tokens, skip_attn=False)
            torch.cuda.synchronize()

    def finish_requests(self, scheduler_output: SchedulerOutput) -> None:
        """Drop per-request state for finished and preempted requests.

        Removes each such request from the request state, the encoder runner
        (multimodal only), the prompt-logprobs worker and the LoRA state.
        """
        finished_req_ids = scheduler_output.finished_req_ids
        preempted_req_ids = scheduler_output.preempted_req_ids
        if preempted_req_ids:
            # union() returns a new set; the scheduler's set is left as is.
            finished_req_ids = finished_req_ids.union(preempted_req_ids)
        for req_id in finished_req_ids:
            self.req_states.remove_request(req_id)
            if self.supports_mm_inputs:
                self.encoder_runner.remove_request(req_id)
            self.prompt_logprobs_worker.remove_request(req_id)
            self.lora_state.remove_request(req_id)

    def free_states(self, scheduler_output: SchedulerOutput) -> None:
        """Free encoder-cache entries the scheduler marked for release.

        Only applies to multimodal models; text-only models do nothing.
        """
        if self.supports_mm_inputs:
            for mm_hash in scheduler_output.free_encoder_mm_hashes:
                self.encoder_runner.free_encoder_cache(mm_hash)

    def add_requests(self, scheduler_output: SchedulerOutput) -> None:
        """Register newly scheduled requests with all per-request components.

        For each new request this adds it to the request state, encoder
        runner (multimodal), M-RoPE state (if used), block tables, sampler,
        prompt-logprobs worker and LoRA state. Staged writes are applied
        once at the end if any request was added.
        """
        for new_req_data in scheduler_output.scheduled_new_reqs:
            assert new_req_data.prompt_token_ids is not None
            assert new_req_data.prefill_token_ids is not None
            assert new_req_data.sampling_params is not None
            req_id = new_req_data.req_id
            prompt_len = len(new_req_data.prompt_token_ids)
            self.req_states.add_request(
                req_id=req_id,
                prompt_len=prompt_len,
                all_token_ids=new_req_data.prefill_token_ids,
                num_computed_tokens=new_req_data.num_computed_tokens,
            )
            # Slot index assigned by req_states.add_request above.
            req_index = self.req_states.req_id_to_index[req_id]

            if self.supports_mm_inputs:
                self.encoder_runner.add_request(req_id, new_req_data.mm_features)

            # Pre-compute M-RoPE positions for prefill.
            if self.uses_mrope:
                self.mrope_states.init_prefill_mrope_positions(
                    req_index,
                    self.model,  # type: ignore
                    new_req_data.prefill_token_ids,
                    mm_features=new_req_data.mm_features,
                )

            # New requests replace whatever block ids the slot held before.
            self.block_tables.append_block_ids(
                req_index, new_req_data.block_ids, overwrite=True
            )
            self.sampler.add_request(
                req_index, prompt_len, new_req_data.sampling_params
            )
            self.prompt_logprobs_worker.add_request(
                req_id, req_index, new_req_data.sampling_params
            )
            self.lora_state.add_request(req_id, req_index, new_req_data.lora_request)

        # Presumably commits writes staged by the add_request calls above;
        # only needed when at least one request was added.
        if scheduler_output.scheduled_new_reqs:
            self.req_states.apply_staged_writes()
            self.sampler.apply_staged_writes()
            if self.uses_mrope:
                self.mrope_states.apply_staged_writes()

    def update_requests(self, scheduler_output: SchedulerOutput) -> None:
        """Append newly allocated KV blocks for already-running requests.

        Requests whose ``new_block_ids`` entry is None are left unchanged.
        """
        # Add new blocks for the existing requests.
        reqs = scheduler_output.scheduled_cached_reqs
        for req_new_block_ids, req_id in zip(reqs.new_block_ids, reqs.req_ids):
            if req_new_block_ids is not None:
                req_index = self.req_states.req_id_to_index[req_id]
                self.block_tables.append_block_ids(
                    req_index, req_new_block_ids, overwrite=False
                )

    def prepare_inputs(
550
        self, scheduler_output: SchedulerOutput, num_tokens_after_padding: int
Woosuk Kwon's avatar
Woosuk Kwon committed
551
552
553
    ) -> InputBatch:
        num_tokens = scheduler_output.total_num_scheduled_tokens
        assert num_tokens > 0
554
555
        num_tokens_per_req = scheduler_output.num_scheduled_tokens
        num_reqs = len(num_tokens_per_req)
Woosuk Kwon's avatar
Woosuk Kwon committed
556
557
558

        # Decode first, then prefill.
        # batch_idx -> req_id
559
560
561
        req_ids = sorted(num_tokens_per_req, key=num_tokens_per_req.get)  # type: ignore[arg-type]
        numtoks_iter = map(num_tokens_per_req.get, req_ids)
        num_scheduled_tokens = np.fromiter(numtoks_iter, dtype=np.int32, count=num_reqs)
Woosuk Kwon's avatar
Woosuk Kwon committed
562

563
564
        idx_mapping_iter = map(self.req_states.req_id_to_index.get, req_ids)
        idx_mapping_np = np.fromiter(idx_mapping_iter, dtype=np.int32, count=num_reqs)
565
        idx_mapping = async_copy_to_gpu(idx_mapping_np, device=self.device)
Woosuk Kwon's avatar
Woosuk Kwon committed
566

567
        # Get the number of draft tokens for each request.
568
569
        draft_tokens = scheduler_output.scheduled_spec_decode_tokens
        if not draft_tokens:
570
571
572
            # No draft token scheduled (common case).
            total_num_draft_tokens = 0
            total_num_logits = num_reqs
573
            cu_num_logits_np = np.arange(num_reqs + 1, dtype=np.int32)
574
575
576
            cu_num_logits = torch.arange(
                num_reqs + 1, device=self.device, dtype=torch.int32
            )
577
            expanded_idx_mapping = idx_mapping
578
579
580
            expanded_local_pos = torch.zeros(
                num_reqs, dtype=torch.int32, device=self.device
            )
581
582
        else:
            num_draft_tokens = np.array(
583
                [len(draft_tokens.get(req_id, ())) for req_id in req_ids],
584
585
586
587
588
                dtype=np.int32,
            )
            total_num_draft_tokens = int(num_draft_tokens.sum())
            total_num_logits = num_reqs + total_num_draft_tokens

589
590
591
592
            num_logits = num_draft_tokens + 1
            cu_num_logits_np = np.empty(num_reqs + 1, dtype=np.int32)
            cu_num_logits_np[0] = 0
            np.cumsum(num_logits, out=cu_num_logits_np[1:])
593
            cu_num_logits = async_copy_to_gpu(cu_num_logits_np, device=self.device)
594

595
            max_expand_len = self.num_speculative_steps + 1
596
            expanded_idx_mapping, expanded_local_pos = expand_idx_mapping(
597
                idx_mapping, total_num_logits, cu_num_logits, max_expand_len
598
599
            )

Woosuk Kwon's avatar
Woosuk Kwon committed
600
601
602
        # Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
        block_tables = self.block_tables.gather_block_tables(idx_mapping)

603
        # Get query_start_loc.
604
605
606
        query_start_loc_np = np.empty(self.max_num_reqs + 1, dtype=np.int32)
        query_start_loc_np[0] = 0
        np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1 : num_reqs + 1])
607
608
        # Pad for full CUDA graph mode.
        # Some attention backends like FA3 require query_start_loc to be non-decreasing.
609
        query_start_loc_np[num_reqs + 1 :] = num_tokens
610
611
        async_copy_to_gpu(query_start_loc_np, out=self.input_buffers.query_start_loc)

612
613
614
        query_start_loc_np = query_start_loc_np[: num_reqs + 1]
        query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
        query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
615
        max_query_len = num_scheduled_tokens.max().item()
616

617
        # Get prefill tokens.
618
        prepare_prefill_inputs(
619
620
621
            self.input_buffers.input_ids,
            self.req_states.next_prefill_tokens,
            idx_mapping,
622
            query_start_loc,
623
            self.req_states.all_token_ids.gpu,
624
            self.req_states.prefill_len.gpu,
625
            self.req_states.num_computed_tokens.gpu,
Woosuk Kwon's avatar
Woosuk Kwon committed
626
627
        )

628
629
630
        # Prepare positions and seq_lens.
        prepare_pos_seq_lens(
            idx_mapping,
631
632
            query_start_loc,
            self.req_states.num_computed_tokens.gpu,
633
634
635
636
637
            self.input_buffers.positions,
            self.input_buffers.seq_lens,
        )
        seq_lens = self.input_buffers.seq_lens[:num_reqs]

638
639
640
641
642
643
644
645
646
647
648
649
650
        dcp_size = self.parallel_config.decode_context_parallel_size
        if dcp_size > 1:
            prepare_dcp_local_seq_lens(
                self.input_buffers.dcp_local_seq_lens,
                seq_lens,
                num_reqs,
                dcp_size=dcp_size,
                dcp_rank=get_dcp_group().rank_in_group,
                cp_kv_cache_interleave_size=(
                    self.parallel_config.cp_kv_cache_interleave_size
                ),
            )

651
652
653
654
655
656
657
658
659
        # Prepare M-RoPE positions.
        if self.uses_mrope:
            self.mrope_states.prepare_mrope_positions(
                idx_mapping,
                query_start_loc,
                self.req_states.prefill_len.gpu,
                self.req_states.num_computed_tokens.gpu,
            )

660
        # Some input token ids are directly read from the last sampled tokens
661
662
        # and draft tokens. Also, get the logits indices to sample tokens from.
        logits_indices = combine_sampled_and_draft_tokens(
663
            self.input_buffers.input_ids,
Woosuk Kwon's avatar
Woosuk Kwon committed
664
665
            idx_mapping,
            self.req_states.last_sampled_tokens,
666
            query_start_loc,
667
668
            seq_lens,
            self.req_states.prefill_len.gpu,
669
670
671
            self.req_states.draft_tokens,
            cu_num_logits,
            total_num_logits,
Woosuk Kwon's avatar
Woosuk Kwon committed
672
673
674
675
        )

        # Compute slot mappings: [num_kv_cache_groups, num_tokens]
        slot_mappings = self.block_tables.compute_slot_mappings(
676
677
678
            idx_mapping,
            query_start_loc,
            self.input_buffers.positions[:num_tokens],
Woosuk Kwon's avatar
Woosuk Kwon committed
679
        )
680
681
682
683
        # Layer name -> slot mapping.
        slot_mappings_by_layer = build_slot_mappings_by_layer(
            slot_mappings, self.kv_cache_config
        )
Woosuk Kwon's avatar
Woosuk Kwon committed
684
685
686
687
688
689

        # Layer name -> attention metadata.
        attn_metadata = build_attn_metadata(
            attn_metadata_builders=self.attn_metadata_builders,
            num_reqs=num_reqs,
            num_tokens=num_tokens,
690
            query_start_loc_gpu=query_start_loc,
691
            query_start_loc_cpu=query_start_loc_cpu,
692
            max_query_len=max_query_len,
Woosuk Kwon's avatar
Woosuk Kwon committed
693
            seq_lens=self.input_buffers.seq_lens,
694
            max_seq_len=self.max_model_len,
Woosuk Kwon's avatar
Woosuk Kwon committed
695
696
697
            block_tables=block_tables,
            slot_mappings=slot_mappings,
            kv_cache_config=self.kv_cache_config,
698
            dcp_local_seq_lens=self.input_buffers.dcp_local_seq_lens,
Woosuk Kwon's avatar
Woosuk Kwon committed
699
700
        )

701
        input_ids = self.input_buffers.input_ids[:num_tokens_after_padding]
702
        positions = self.input_buffers.positions[:num_tokens_after_padding]
703
704
        mrope_positions = None
        if self.uses_mrope:
705
706
            mrope_positions = self.mrope_states.mrope_positions
            mrope_positions = mrope_positions[:, :num_tokens_after_padding]
Woosuk Kwon's avatar
Woosuk Kwon committed
707
708
709
710
711
        return InputBatch(
            req_ids=req_ids,
            num_reqs=num_reqs,
            idx_mapping=idx_mapping,
            idx_mapping_np=idx_mapping_np,
712
            expanded_idx_mapping=expanded_idx_mapping,
713
            expanded_local_pos=expanded_local_pos,
Woosuk Kwon's avatar
Woosuk Kwon committed
714
715
716
            num_scheduled_tokens=num_scheduled_tokens,
            num_tokens=num_tokens,
            num_tokens_after_padding=num_tokens_after_padding,
717
            num_draft_tokens=total_num_draft_tokens,
718
            query_start_loc=query_start_loc,
Woosuk Kwon's avatar
Woosuk Kwon committed
719
            query_start_loc_np=query_start_loc_np,
720
            seq_lens=seq_lens,
Woosuk Kwon's avatar
Woosuk Kwon committed
721
722
            input_ids=input_ids,
            positions=positions,
723
            mrope_positions=mrope_positions,
724
            inputs_embeds=None,
Woosuk Kwon's avatar
Woosuk Kwon committed
725
            attn_metadata=attn_metadata,
726
            slot_mappings=slot_mappings_by_layer,
Woosuk Kwon's avatar
Woosuk Kwon committed
727
            logits_indices=logits_indices,
728
            cu_num_logits=cu_num_logits,
729
            cu_num_logits_np=cu_num_logits_np,
730
            has_structured_output_reqs=scheduler_output.has_structured_output_requests,
Woosuk Kwon's avatar
Woosuk Kwon committed
731
732
        )

733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
    @torch.inference_mode()
    def get_mm_embeddings(
        self,
        scheduled_encoder_inputs: dict[str, list[int]],
        input_batch: InputBatch,
    ) -> tuple[list[torch.Tensor], torch.Tensor]:
        """Run the multimodal encoder and gather embeddings for this batch.

        The encoder inputs scheduled for this step are prepared and executed
        on ``self.model``; the resulting embeddings are then gathered for the
        requests in ``input_batch``.

        Args:
            scheduled_encoder_inputs: The scheduler's
                ``scheduled_encoder_inputs`` mapping for this step.
            input_batch: The prepared input batch for this step.

        Returns:
            The ``(mm_embeds, is_mm_embed)`` pair produced by
            ``encoder_runner.gather_mm_embeddings``.
        """
        encoder = self.encoder_runner
        batch_slots = input_batch.idx_mapping_np

        # Encode every multimodal item scheduled for this step.
        hashes, encoder_kwargs = encoder.prepare_mm_inputs(scheduled_encoder_inputs)
        encoder.execute_mm_encoder(self.model, hashes, encoder_kwargs)

        # Per-request prefill bookkeeping, reordered to match the batch.
        prefill_lens = self.req_states.prefill_len.np[batch_slots]
        prefill_done = self.req_states.num_computed_prefill_tokens[batch_slots]

        embeds, embed_mask = encoder.gather_mm_embeddings(
            input_batch.req_ids,
            input_batch.num_tokens,
            input_batch.num_scheduled_tokens,
            input_batch.query_start_loc_np,
            prefill_lens,
            prefill_done,
        )
        return embeds, embed_mask

Woosuk Kwon's avatar
Woosuk Kwon committed
753
754
755
756
757
    def sample(
        self,
        hidden_states: torch.Tensor,
        input_batch: InputBatch,
        grammar_output: GrammarOutput | None,
    ) -> tuple[SamplerOutput, torch.Tensor, torch.Tensor]:
        """Compute logits for the batch and sample the next tokens.

        Args:
            hidden_states: Final hidden states from the model forward pass.
            input_batch: The input batch that produced ``hidden_states``.
            grammar_output: Structured-output bitmask to apply to the logits,
                or ``None`` when no grammar constraint applies.

        Returns:
            ``(sampler_output, num_sampled, num_rejected)``. When the batch
            carries draft tokens, ``sampler_output.sampled_token_ids`` is
            replaced by the output of rejection sampling. For chunked
            prefills both counts are 0.
        """
        logit_rows = input_batch.logits_indices
        # Only the rows that need logits are gathered.
        rows_hidden = hidden_states[logit_rows]
        rows_pos = input_batch.positions[logit_rows]
        rows_token_ids = input_batch.input_ids[logit_rows]

        logits = self.model.compute_logits(rows_hidden)
        if grammar_output is not None:
            # The grammar bitmask is applied to `logits` in place.
            self.structured_outputs_worker.apply_grammar_bitmask(
                logits,
                input_batch,
                grammar_output.structured_output_request_ids,
                grammar_output.grammar_bitmask,
            )

        # Sample tokens (and compute logprobs when requested).
        out = self.sampler(
            logits,
            input_batch.expanded_idx_mapping,
            input_batch.idx_mapping_np,
            input_batch.cu_num_logits_np,
            rows_pos,
            rows_token_ids,
            input_batch.expanded_local_pos,
        )

        if input_batch.num_draft_tokens != 0:
            # Speculative decoding: check the draft tokens with rejection
            # sampling; its tokens replace the sampler's token IDs.
            accepted, per_req_sampled = rejection_sample(
                out.sampled_token_ids,
                rows_token_ids,
                input_batch.cu_num_logits,
                self.num_speculative_steps,
            )
            out.sampled_token_ids = accepted
        else:
            # Common case without drafts: a count of 1 for every request.
            per_req_sampled = torch.ones(
                input_batch.num_reqs, dtype=torch.int32, device=self.device
            )

        # Final per-request sampled/rejected counts (both 0 for chunked
        # prefills).
        num_sampled, num_rejected = get_num_sampled_and_rejected(
            per_req_sampled,
            input_batch.seq_lens,
            input_batch.cu_num_logits,
            input_batch.idx_mapping,
            self.req_states.prefill_len.gpu,
        )
        return out, num_sampled, num_rejected
Woosuk Kwon's avatar
Woosuk Kwon committed
808
809
810
811

    def postprocess(
        self,
        input_batch: InputBatch,
        sampled_tokens: torch.Tensor,
        num_sampled: torch.Tensor,
        num_rejected: torch.Tensor,
    ) -> None:
        """Fold the results of one executed step back into request state.

        Args:
            input_batch: The batch that was just executed.
            sampled_tokens: Sampled token IDs for the batch.
            num_sampled: Per-request count of sampled tokens.
            num_rejected: Per-request count of rejected draft tokens.
        """
        states = self.req_states

        # Device-side update (including the number of computed tokens). It is
        # handed the computed-token counts, last sampled tokens, penalty bin
        # counts and the token-history buffers.
        post_update(
            input_batch.idx_mapping,
            states.num_computed_tokens.gpu,
            states.last_sampled_tokens,
            self.sampler.penalties_state.output_bin_counts,
            sampled_tokens,
            num_sampled,
            num_rejected,
            input_batch.query_start_loc,
            states.all_token_ids.gpu,
            states.total_len.gpu,
        )

        # Host-side: advance the prefill progress of the scheduled requests
        # by their scheduled token counts, then clamp every slot so progress
        # never exceeds its prefill length.
        progress = states.num_computed_prefill_tokens
        progress[input_batch.idx_mapping_np] += input_batch.num_scheduled_tokens
        np.minimum(progress, states.prefill_len.np, out=progress)

838
839
840
841
842
843
844
    @torch.inference_mode()
    def propose_draft(
        self,
        input_batch: InputBatch,
        last_hidden_states: torch.Tensor,
        aux_hidden_states: list[torch.Tensor] | None,
        num_sampled: torch.Tensor,
        num_rejected: torch.Tensor,
    ) -> torch.Tensor:
        """Ask the speculator for draft tokens for the next step.

        Requires ``self.speculator`` to be set (asserted).

        Args:
            input_batch: The batch that was just executed.
            last_hidden_states: Final hidden states from the forward pass.
            aux_hidden_states: Optional auxiliary hidden states.
            num_sampled: Per-request number of sampled tokens.
            num_rejected: Per-request number of rejected draft tokens.

        Returns:
            Whatever ``speculator.propose`` returns (the draft tokens).
        """
        speculator = self.speculator
        assert speculator is not None
        states = self.req_states
        sampling = self.sampler.sampling_states
        return speculator.propose(
            input_batch,
            last_hidden_states,
            aux_hidden_states,
            num_sampled,
            num_rejected,
            states.last_sampled_tokens,
            states.next_prefill_tokens,
            sampling.temperature.gpu,
            sampling.seeds.gpu,
        )

Woosuk Kwon's avatar
Woosuk Kwon committed
861
862
863
864
    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: SchedulerOutput,
        intermediate_tensors: IntermediateTensors | None = None,
        dummy_run: bool = False,
        skip_attn_for_dummy_run: bool = False,
    ) -> ModelRunnerOutput | IntermediateTensors | None:
        """Run the model forward pass for one scheduler step.

        Sampling is deferred to ``sample_tokens``. This method stores
        ``(hidden_states_or_None, input_batch, kv_connector_output)`` in
        ``self.execute_model_state`` for it to consume.

        Args:
            scheduler_output: The scheduler's decisions for this step.
            intermediate_tensors: Activations from the previous pipeline
                rank. In eager mode they must be ``None`` on the first PP rank
                and non-``None`` on the other ranks.
            dummy_run: If True, skip request-state updates and run a dummy
                batch (a dummy run for DP or memory profiling).
            skip_attn_for_dummy_run: In a dummy run, do not prepare dummy
                attention metadata.

        Returns:
            - The KV connector's ``no_forward`` output when this rank has no
              scheduled tokens (non-dummy run), or when all DP ranks have
              zero tokens to run.
            - ``IntermediateTensors`` on a non-last pipeline-parallel rank.
            - ``None`` on the last rank; ``sample_tokens`` must follow.
        """
        if not dummy_run:
            # Sync the persistent request state with the scheduler.
            self.finish_requests(scheduler_output)
            self.free_states(scheduler_output)
            self.add_requests(scheduler_output)
            self.update_requests(scheduler_output)
            self.block_tables.apply_staged_writes()
            if scheduler_output.total_num_scheduled_tokens == 0:
                # Nothing scheduled on this rank: no need to run the model.
                return self.kv_connector.no_forward(scheduler_output)

        total_tokens = scheduler_output.total_num_scheduled_tokens
        # CUDA graph size for this batch; None means no CUDA graph is used.
        graph_size = self.cudagraph_manager.get_cudagraph_size(
            total_tokens, scheduler_output.num_scheduled_tokens.values()
        )
        parallel = self.parallel_config
        use_cudagraph, padded_tokens, num_tokens_across_dp = (
            get_cudagraph_and_dp_padding(
                total_tokens,
                graph_size,
                parallel.data_parallel_size,
                parallel.data_parallel_rank,
            )
        )
        if padded_tokens == 0:
            # All DP ranks have zero tokens to run.
            return self.kv_connector.no_forward(scheduler_output)

        if dummy_run:
            # No actual tokens: build a dummy batch instead.
            input_batch = InputBatch.make_dummy(
                num_reqs=min(padded_tokens, self.max_num_reqs),
                num_tokens=padded_tokens,
                input_buffers=self.input_buffers,
                device=self.device,
            )
            if self.uses_mrope:
                input_batch.mrope_positions = self.mrope_states.mrope_positions[
                    :, :padded_tokens
                ]
            if not skip_attn_for_dummy_run:
                self.prepare_dummy_attn_metadata(input_batch)
            # FIXME(woosuk): Fix warmup for LoRA.
        else:
            # Common case: prepare the inputs and copy them into the input
            # buffers.
            input_batch = self.prepare_inputs(scheduler_output, padded_tokens)
            if self.lora_config:
                # Activate the LoRA adapters for this batch.
                adapter_args = self.lora_state.make_lora_inputs(
                    input_batch.req_ids,
                    input_batch.idx_mapping_np,
                    input_batch.num_scheduled_tokens,
                )
                self._set_active_loras(*adapter_args)

            # Multimodal embeddings are prepared on the first PP rank only.
            if self.supports_mm_inputs and self.is_first_pp_rank:
                mm_embeds, is_mm_embed = self.get_mm_embeddings(
                    scheduler_output.scheduled_encoder_inputs, input_batch
                )
                merged_embeds = self.encoder_runner.get_inputs_embeds(
                    self.model, input_batch.input_ids, mm_embeds, is_mm_embed
                )
                input_batch.inputs_embeds = merged_embeds[
                    : input_batch.num_tokens_after_padding
                ]

        if use_cudagraph:
            # Replay the CUDA graph.
            # NOTE(woosuk): The input tensors are not passed here because they
            # are already copied to the CUDA graph input buffers.
            self.kv_connector.pre_forward(scheduler_output)
            hidden_states = self.cudagraph_manager.run(
                input_batch.num_tokens_after_padding
            )
        else:
            # Eager PyTorch forward.
            if self.uses_mrope:
                assert input_batch.mrope_positions is not None
                positions = input_batch.mrope_positions
            else:
                positions = input_batch.positions

            if self.is_first_pp_rank:
                assert intermediate_tensors is None
                input_ids = input_batch.input_ids
                inputs_embeds = input_batch.inputs_embeds
            else:
                assert intermediate_tensors is not None
                input_ids = inputs_embeds = None

            with set_forward_context(
                input_batch.attn_metadata,
                self.vllm_config,
                num_tokens=input_batch.num_tokens_after_padding,
                # TODO(woosuk): Support piecewise CUDA graph.
                cudagraph_runtime_mode=CUDAGraphMode.NONE,
                num_tokens_across_dp=num_tokens_across_dp,
                slot_mapping=input_batch.slot_mappings,
            ):
                self.kv_connector.pre_forward(scheduler_output)
                hidden_states = self.model(
                    input_ids=input_ids,
                    positions=positions,
                    inputs_embeds=inputs_embeds,
                    intermediate_tensors=intermediate_tensors,
                )

        kv_out = self.kv_connector.post_forward(scheduler_output)

        if self.is_last_pp_rank:
            # Last rank (or no PP): keep the hidden states for sampling.
            assert isinstance(hidden_states, torch.Tensor)
            self.execute_model_state = (hidden_states, input_batch, kv_out)
            return None

        # Non-last PP rank: return the IntermediateTensors for sending.
        assert isinstance(hidden_states, IntermediateTensors)
        hidden_states.kv_connector_output = kv_out
        self.execute_model_state = (None, input_batch, kv_out)
        return hidden_states

    @torch.inference_mode()
    def sample_tokens(
        self, grammar_output: GrammarOutput | None
    ) -> AsyncOutput | ModelRunnerOutput | None:
        """Sample tokens from the state stored by ``execute_model``.

        Consumes and clears ``self.execute_model_state``.

        On a non-last pipeline-parallel rank, this receives the tokens sampled
        by the last rank, updates local request state, and returns ``None``.

        On the last rank, it does the following in order:

        - samples, and broadcasts the result when PP is used;
        - computes prompt logprobs;
        - updates request state;
        - proposes draft tokens when spec decoding is on.

        It then returns the ``AsyncOutput`` under async scheduling, or the
        result of ``AsyncOutput.get_output()`` otherwise.

        Args:
            grammar_output: Structured-output bitmask for this step, or
                ``None``.
        """
        assert self.execute_model_state is not None
        hidden_states, input_batch, kv_out = self.execute_model_state
        self.execute_model_state = None  # type: ignore

        if not self.is_last_pp_rank:
            # This rank produced IntermediateTensors instead of final hidden
            # states (hidden_states is None). Receive the sampled tokens
            # broadcast by the last rank and update local state.
            received = pp_receive(
                input_batch.num_reqs, max_sample_len=self.num_speculative_steps + 1
            )
            assert received is not None
            peer_tokens, num_sampled, num_rejected = received
            self.postprocess(input_batch, peer_tokens, num_sampled, num_rejected)
            return None

        # Last rank: sample tokens.
        sampler_output, num_sampled, num_rejected = self.sample(
            hidden_states, input_batch, grammar_output
        )

        if self.use_pp:
            # Share the results with the non-last PP ranks (this also covers
            # multiple tokens per request under spec decoding).
            pp_broadcast(sampler_output.sampled_token_ids, num_sampled, num_rejected)

        states = self.req_states
        prompt_logprobs = self.prompt_logprobs_worker.compute_prompt_logprobs(
            self.model.compute_logits,
            hidden_states,
            input_batch,
            states.all_token_ids.gpu,
            states.num_computed_tokens.gpu,
            states.prompt_len.np,
            states.prefill_len.np,
            states.num_computed_prefill_tokens,
        )

        req_ids = input_batch.req_ids
        runner_output = ModelRunnerOutput(
            req_ids=req_ids,
            # NOTE(woosuk): req_id_to_index is unused in this model runner.
            # Only for compatibility with the existing model runner and scheduler.
            req_id_to_index=dict(zip(req_ids, range(len(req_ids)))),
            sampled_token_ids=None,  # type: ignore
            prompt_logprobs_dict=prompt_logprobs,  # type: ignore[arg-type]
            kv_connector_output=kv_out,
        )
        async_output = AsyncOutput(
            model_runner_output=runner_output,
            sampler_output=sampler_output,
            num_sampled_tokens=num_sampled,
            copy_stream=self.output_copy_stream,
            copy_event=self.output_copy_event,
        )

        # Postprocess only AFTER the AsyncOutput exists, so that `copy_event`
        # is recorded first. The async D2H copy then does not wait for
        # postprocessing, which may slightly reduce latency.
        self.postprocess(
            input_batch, sampler_output.sampled_token_ids, num_sampled, num_rejected
        )

        if self.do_spec_decode:
            drafts = self.propose_draft(
                input_batch,
                hidden_states,
                None,  # aux_hidden_states
                num_sampled,
                num_rejected,
            )
            states.draft_tokens[input_batch.idx_mapping] = drafts
            self.draft_tokens_handler.set_draft_tokens(input_batch, drafts)

        return async_output if self.use_async_scheduling else async_output.get_output()
1078
1079
1080

    def take_draft_token_ids(self) -> DraftTokenIds | None:
        """Return the draft token IDs from the draft-tokens handler.

        Delegates to ``draft_tokens_handler.get_draft_tokens()``; the result
        may be ``None``.
        """
        handler = self.draft_tokens_handler
        return handler.get_draft_tokens()