model_runner.py 45.7 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
"""
NOTE: Coding style guide for this file:
This model runner is shared by all models: text and multimodal, generative
and embedding, public and private. As a result, this file must only contain
code that is common to every model. Model-specific behavior belongs in the
appropriate model-specific files.

In other words:
* Be paranoid about changing this file. It should remain stable.
* Be even more paranoid about adding new lines. It should remain minimal.

Even for shared features (for example, different parallelism modes), keep the
complexity out of this path. The less common the feature, the more it should be
hidden. Prefer utility functions defined elsewhere and call them from here,
instead of embedding feature-specific logic directly.
"""

20
import functools
Woosuk Kwon's avatar
Woosuk Kwon committed
21
22
23
24
25
26
27
28
29
30
import gc
import time
from copy import deepcopy

import numpy as np
import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
31
from vllm.distributed.parallel_state import (
32
    get_dcp_group,
33
34
35
    get_pp_group,
    prepare_communication_buffer_for_model,
)
36
from vllm.forward_context import BatchDescriptor, set_forward_context
Woosuk Kwon's avatar
Woosuk Kwon committed
37
38
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model_loader
39
from vllm.multimodal import MULTIMODAL_REGISTRY
40
from vllm.sequence import IntermediateTensors
41
from vllm.tasks import SupportedTask
42
from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
Woosuk Kwon's avatar
Woosuk Kwon committed
43
44
45
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
46
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
47
from vllm.v1.worker.cp_utils import check_attention_cp_compatibility
48
from vllm.v1.worker.gpu.async_utils import AsyncOutput, AsyncPoolingOutput
Woosuk Kwon's avatar
Woosuk Kwon committed
49
from vllm.v1.worker.gpu.attn_utils import (
50
    build_slot_mappings_by_layer,
Woosuk Kwon's avatar
Woosuk Kwon committed
51
52
53
54
55
    get_kv_cache_spec,
    init_attn_backend,
    init_kv_cache,
)
from vllm.v1.worker.gpu.block_table import BlockTables
56
from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu
57
from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens
Woosuk Kwon's avatar
Woosuk Kwon committed
58
from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
59
from vllm.v1.worker.gpu.dp_utils import (
60
    get_cudagraph_and_dp_padding,
61
62
    make_num_tokens_across_dp,
)
Woosuk Kwon's avatar
Woosuk Kwon committed
63
64
65
from vllm.v1.worker.gpu.input_batch import (
    InputBatch,
    InputBuffers,
66
    combine_sampled_and_draft_tokens,
67
    expand_idx_mapping,
68
    get_num_sampled_and_rejected,
69
    post_update,
70
    post_update_pool,
71
72
    prepare_pos_seq_lens,
    prepare_prefill_inputs,
Woosuk Kwon's avatar
Woosuk Kwon committed
73
)
74
75
76
77
78
from vllm.v1.worker.gpu.kv_connector import (
    NO_OP_KV_CONNECTOR,
    KVConnector,
    get_kv_connector,
)
79
from vllm.v1.worker.gpu.lora_utils import LoraState
80
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
81
from vllm.v1.worker.gpu.model_states import ModelState
82
from vllm.v1.worker.gpu.pool.pooling_runner import PoolingRunner
83
from vllm.v1.worker.gpu.pp_utils import pp_broadcast, pp_receive
84
from vllm.v1.worker.gpu.sample.output import SamplerOutput
85
from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker
86
from vllm.v1.worker.gpu.sample.sampler import Sampler
87
from vllm.v1.worker.gpu.spec_decode import init_speculator
88
89
90
from vllm.v1.worker.gpu.spec_decode.eagle.eagle3_utils import (
    set_eagle3_aux_hidden_state_layers,
)
91
from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample
92
from vllm.v1.worker.gpu.spec_decode.utils import DraftTokensHandler
93
from vllm.v1.worker.gpu.states import RequestState
94
from vllm.v1.worker.gpu.structured_outputs import StructuredOutputsWorker
Woosuk Kwon's avatar
Woosuk Kwon committed
95
96
97
98
99
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

logger = init_logger(__name__)


100
class GPUModelRunner(LoRAModelRunnerMixin):
Woosuk Kwon's avatar
Woosuk Kwon committed
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Set up runner state shared by every model type.

        The model itself is loaded later by ``load_model`` and the KV cache is
        allocated later by ``initialize_kv_cache``.

        Args:
            vllm_config: Engine-wide configuration bundle.
            device: Device this runner executes on.

        Raises:
            ValueError: If EAGLE3 speculative decoding is combined with
                pipeline parallelism.
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.compilation_config = vllm_config.compilation_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.speculative_config = vllm_config.speculative_config
        self.observability_config = vllm_config.observability_config

        self.device = device
        self.dtype = self.model_config.dtype
        self.kv_cache_dtype = self.dtype
        if self.cache_config.cache_dtype != "auto":
            # Quantized KV cache.
            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                self.cache_config.cache_dtype
            ]

        self.vocab_size = self.model_config.get_vocab_size()
        self.max_model_len = self.model_config.max_model_len
        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
        self.max_num_reqs = self.scheduler_config.max_num_seqs

        self.use_async_scheduling = self.scheduler_config.async_scheduling
        # Dedicated stream/event for output copies (their use is outside this
        # chunk; presumably to overlap copies with compute — confirm).
        self.output_copy_stream = torch.cuda.Stream(self.device)
        self.output_copy_event = torch.cuda.Event()

        # Pipeline parallelism.
        self.pp_size = self.parallel_config.pipeline_parallel_size
        self.use_pp = self.pp_size > 1
        if self.use_pp:
            pp_group = get_pp_group()
            self.is_first_pp_rank = pp_group.is_first_rank
            self.is_last_pp_rank = pp_group.is_last_rank
        else:
            self.is_first_pp_rank = True
            self.is_last_pp_rank = True

        # Decode context parallelism.
        self.dcp_size = self.parallel_config.decode_context_parallel_size
        self.use_dcp = self.dcp_size > 1
        self.dcp_rank = get_dcp_group().rank_in_group if self.use_dcp else 0
        self.cp_interleave = self.parallel_config.cp_kv_cache_interleave_size

        # Multimodal. Only the first PP rank holds an encoder cache.
        self.mm_registry = MULTIMODAL_REGISTRY
        self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
            self.model_config
        )
        self.encoder_cache = None
        if self.supports_mm_inputs and self.is_first_pp_rank:
            self.encoder_cache = EncoderCache()

        # Speculative decoding.
        self.speculator = None
        self.num_speculative_steps = 0
        self.use_aux_hidden_state_outputs = False
        if self.speculative_config is not None:
            self.num_speculative_steps = self.speculative_config.num_speculative_tokens
            if self.speculative_config.method == "eagle3":
                # Reject the unsupported combination before building the
                # speculator so the failure is immediate.
                if self.use_pp:
                    raise ValueError("EAGLE3 with pipeline parallel is not supported.")
                # EAGLE3 may require auxiliary hidden states from target model outputs.
                self.use_aux_hidden_state_outputs = True
            # The speculator lives only on the last PP rank (the sampling rank).
            if self.is_last_pp_rank:
                self.speculator = init_speculator(self.vllm_config, self.device)

        # Draft tokens propagation - for spec-dec + struct outputs.
        self.draft_tokens_handler = DraftTokensHandler(self.device)

        self.req_states = RequestState(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_batched_tokens=self.max_num_tokens,
            num_speculative_steps=self.num_speculative_steps,
            vocab_size=self.vocab_size,
            device=self.device,
        )
        self.input_buffers = InputBuffers(
            max_num_reqs=self.max_num_reqs,
            max_num_tokens=self.max_num_tokens,
            device=self.device,
        )
        # The sampler shares `req_states`, so it must be built after it.
        self.sampler = Sampler(
            max_num_reqs=self.max_num_reqs,
            vocab_size=self.vocab_size,
            device=self.device,
            req_states=self.req_states,
            logprobs_mode=self.model_config.logprobs_mode,
            num_speculative_tokens=self.num_speculative_steps + 1,
        )
        self.prompt_logprobs_worker = PromptLogprobsWorker(self.max_num_reqs)

        # CUDA graphs.
        self.cudagraph_manager = CudaGraphManager(
            self.vllm_config,
            self.use_aux_hidden_state_outputs,
            self.device,
        )
        # Structured outputs worker: up to (draft tokens + 1) logits per request.
        self.structured_outputs_worker = StructuredOutputsWorker(
            max_num_logits=self.max_num_reqs * (self.num_speculative_steps + 1),
            vocab_size=self.vocab_size,
            device=self.device,
        )
        # LoRA-related workers.
        self.lora_state = LoraState(max_num_reqs=self.max_num_reqs)
        # KV connector placeholder; `initialize_kv_cache` installs the real one.
        self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR

        # Pooling models.
        self.is_pooling_model = self.model_config.runner_type == "pooling"
        self.pooling_runner: PoolingRunner | None = None

        # For transferring state from execute_model to subsequent sample_tokens call.
        self.execute_model_state: tuple | None = None

    def update_max_model_len(self, max_model_len: int) -> None:
        """Apply a new maximum model length to the runner and its request state."""
        # Chained assignment binds targets left to right: runner first, then
        # the request state.
        self.max_model_len = self.req_states.max_model_len = max_model_len

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        """Report the tasks this runner can serve.

        ``"generate"`` comes first for generative runners, followed by any
        pooling tasks advertised by the pooling runner (created in
        ``load_model`` for pooling models).
        """
        is_generative = self.model_config.runner_type == "generate"
        generate_part = ("generate",) if is_generative else ()
        runner = self.pooling_runner
        pooling_part = (
            () if runner is None else tuple(runner.get_supported_pooling_tasks())
        )
        return generate_part + pooling_part

    def load_model(self, *args, **kwargs) -> None:
        """Load the target model and build the components that need it.

        Applies LoRA wrapping and EAGLE3 auxiliary-layer setup when configured,
        loads the speculator, records the memory the load consumed in
        ``self.model_memory_usage``, and creates ``self.model_state`` (plus the
        pooling runner for pooling models). Extra ``*args``/``**kwargs`` are
        accepted for interface compatibility and ignored.
        """
        time_before_load = time.perf_counter()
        with DeviceMemoryProfiler() as m:
            # `self.load_config` / `self.model_config` are the same objects as
            # on `self.vllm_config` (set in __init__).
            model_loader = get_model_loader(self.load_config)
            logger.info("Loading model from scratch...")

            self.model = model_loader.load_model(
                vllm_config=self.vllm_config,
                model_config=self.model_config,
            )
            if self.lora_config:
                self.model = self.load_lora_model(
                    self.model, self.vllm_config, self.device
                )

            if self.use_aux_hidden_state_outputs:
                # EAGLE3 reads auxiliary hidden states from the target model.
                assert self.speculative_config is not None
                set_eagle3_aux_hidden_state_layers(self.model, self.speculative_config)
            if self.speculator is not None:
                self.speculator.load_model(self.model)
        time_after_load = time.perf_counter()

        self.model_memory_usage = m.consumed_memory
        logger.info(
            "Model loading took %s GiB and %.6f seconds",
            format_gib(m.consumed_memory),
            time_after_load - time_before_load,
        )

        prepare_communication_buffer_for_model(self.model)
        if self.speculator is not None:
            prepare_communication_buffer_for_model(self.speculator)

        # Initialize the components that require the model.
        self.model_state = ModelState(
            self.vllm_config, self.model, self.encoder_cache, self.device
        )
        if self.is_pooling_model:
            self.pooling_runner = PoolingRunner(self.model)

    def get_model(self) -> nn.Module:
        """Return the model created by ``load_model``."""
        loaded = self.model
        return loaded

    @functools.cached_property
    def main_stream(self) -> torch.cuda.Stream:
        """The current CUDA stream of ``self.device``, captured on first access.

        Cached so hot paths skip the stream lookup on every use.
        """
        stream = torch.cuda.current_stream(self.device)
        return stream

    def get_kv_cache_spec(self):
        """Return the KV cache spec computed from this runner's config."""
        # Delegates to the module-level helper of the same name.
        config = self.vllm_config
        return get_kv_cache_spec(config)

    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """Build block tables, attention backends, KV cache tensors and the
        KV connector for ``kv_cache_config``.

        Args:
            kv_cache_config: KV cache layout. A deep copy is stored on the
                runner, so later changes to the caller's object are not seen.
        """
        kv_cache_config = deepcopy(kv_cache_config)
        self.kv_cache_config = kv_cache_config
        # One block size per KV cache group.
        block_sizes = [
            kv_cache_group.kv_cache_spec.block_size
            for kv_cache_group in kv_cache_config.kv_cache_groups
        ]

        self.block_tables = BlockTables(
            block_sizes=block_sizes,
            max_num_reqs=self.max_num_reqs,
            max_num_batched_tokens=self.max_num_tokens,
            max_model_len=self.max_model_len,
            device=self.device,
            cp_size=self.dcp_size,
            cp_rank=self.dcp_rank,
            cp_interleave=self.cp_interleave,
        )

        self.attn_backends, self.attn_groups = init_attn_backend(
            self.kv_cache_config, self.vllm_config, self.device
        )
        check_attention_cp_compatibility(self.vllm_config)
        if self.speculator is not None:
            # HACK(woosuk)
            self.speculator.set_attn(
                self.kv_cache_config,
                self.attn_groups,
                self.block_tables,
            )

        # NOTE(review): `init_kv_cache` receives this empty list and
        # presumably fills it in place — confirm against its definition.
        self.kv_caches: list[torch.Tensor] = []
        kv_caches_dict = init_kv_cache(
            self.kv_caches,
            self.compilation_config.static_forward_context,
            self.kv_cache_config,
            self.attn_backends,
            self.device,
        )
        # Replace the no-op placeholder connector installed in __init__.
        self.kv_connector = get_kv_connector(self.vllm_config, kv_caches_dict)

    @torch.inference_mode()
    def _dummy_run(
        self, num_tokens: int, *args, skip_attn: bool = True, **kwargs
    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
        """Run one forward pass over a synthetic batch of ``num_tokens`` tokens.

        Tokens are split evenly over ``min(num_tokens, max_num_reqs)`` fake
        requests, with the remainder added to the last one. The KV connector
        is disabled for the run and re-enabled afterwards, even if the
        forward pass raises.

        Args:
            num_tokens: Total number of tokens in the dummy batch; must be > 0.
            skip_attn: Forwarded to ``execute_model`` as
                ``skip_attn_for_dummy_run``.

        Returns:
            ``(hidden_states, sample_hidden_states)`` on the last PP rank;
            ``(None, None)`` on every other rank.
        """
        # Without this, 0 fails with ZeroDivisionError and negatives with
        # IndexError further down.
        assert num_tokens > 0, f"num_tokens must be positive, got {num_tokens}"

        # Create a dummy scheduler output.
        num_reqs = min(num_tokens, self.max_num_reqs)
        num_tokens_per_request = [num_tokens // num_reqs] * num_reqs
        num_tokens_per_request[-1] += num_tokens % num_reqs
        assert sum(num_tokens_per_request) == num_tokens
        num_scheduled_tokens = {
            f"_dummy_req_{i}": n for i, n in enumerate(num_tokens_per_request)
        }
        dummy_scheduler_output = SchedulerOutput.make_empty()
        dummy_scheduler_output.total_num_scheduled_tokens = num_tokens
        dummy_scheduler_output.num_scheduled_tokens = num_scheduled_tokens

        # Disable any use of KVConnector for dummy runs. Re-enable it in
        # `finally` so a failing dummy run (e.g. OOM while profiling) cannot
        # leave the connector permanently disabled.
        self.kv_connector.set_disabled(True)
        try:
            # For non-first PP ranks, create dummy intermediate_tensors.
            intermediate_tensors = None
            if not self.is_first_pp_rank:
                intermediate_tensors = self.model.make_empty_intermediate_tensors(
                    batch_size=num_tokens,
                    dtype=self.model_config.dtype,
                    device=self.device,
                )

            # Execute the model.
            self.execute_model(
                dummy_scheduler_output,
                intermediate_tensors=intermediate_tensors,
                dummy_run=True,
                skip_attn_for_dummy_run=skip_attn,
            )
        finally:
            self.kv_connector.set_disabled(False)

        # Non-last PP ranks don't produce output for sampling.
        if not self.is_last_pp_rank:
            return None, None

        # Consume (and clear) the state execute_model stashed for sampling.
        assert self.execute_model_state is not None
        input_batch, _, _, _, hidden_states, _, _ = self.execute_model_state
        self.execute_model_state = None
        assert hidden_states is not None  # Last PP rank always has hidden_states
        sample_hidden_states = hidden_states[input_batch.logits_indices]
        return hidden_states, sample_hidden_states

    @torch.inference_mode()
    def _dummy_sampler_run(self, hidden_states: torch.Tensor) -> None:
        """Run the sampler once on dummy inputs (used by ``profile_run``).

        Args:
            hidden_states: Hidden states at the sampling positions; the size
                of the first dimension is used as the number of dummy requests.
        """
        num_reqs = hidden_states.shape[0]
        logits = self.model.compute_logits(hidden_states)
        # Identity index mapping; positions, input ids and local positions
        # are all zero.
        idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=self.device)
        idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
        pos = torch.zeros(num_reqs, dtype=torch.int64, device=self.device)
        dummy_input_ids = torch.zeros(num_reqs, dtype=torch.int32, device=self.device)
        expanded_local_pos = torch.zeros(
            num_reqs, dtype=torch.int32, device=self.device
        )

        # NOTE(woosuk): During the initial memory profiling, the sampler may skip
        # top_k, top_p, and logprobs, using less GPU memory than what is possible
        # during actual execution.
        # NOTE(review): `idx_mapping_np` fills two consecutive positional
        # arguments; presumably the unexpanded and expanded mappings, which
        # coincide when there are no draft tokens — confirm against Sampler.
        self.sampler(
            logits,
            idx_mapping,
            idx_mapping_np,
            idx_mapping_np,
            pos,
            dummy_input_ids,
            expanded_local_pos,
        )

    @torch.inference_mode()
    def _dummy_pooler_run(self, hidden_states: torch.Tensor) -> None:
        """Hand a profiling pass over to the pooling runner (pooling models only)."""
        runner = self.pooling_runner
        assert runner is not None
        runner.dummy_pooler_run(hidden_states)

    @torch.inference_mode()
    def profile_run(self) -> None:
        """Run one profiling pass at the full token budget.

        Performs a dummy forward pass of ``max_num_tokens`` tokens with
        attention skipped. On the last PP rank it then runs the sampler (or
        the pooler for pooling models) and, if configured, the speculator.
        Finally it synchronizes the device and drops the dummy outputs.
        """
        hidden_states, sample_hidden_states = self._dummy_run(
            self.max_num_tokens, skip_attn=True
        )

        # Only run sampler/pooler on last PP rank (non-last ranks return None).
        if self.is_last_pp_rank:
            assert sample_hidden_states is not None
            if self.pooling_runner is None:
                self._dummy_sampler_run(sample_hidden_states)
            else:
                self._dummy_pooler_run(hidden_states)

            if self.speculator is not None:
                num_tokens_across_dp = make_num_tokens_across_dp(
                    self.parallel_config.data_parallel_size, self.max_num_tokens
                )
                self.speculator.run_model(
                    self.max_num_tokens,
                    attn_metadata=None,
                    slot_mappings=None,
                    num_tokens_across_dp=num_tokens_across_dp,
                )

        torch.cuda.synchronize()
        # Drop references so the dummy activations can be reclaimed.
        del hidden_states, sample_hidden_states
        gc.collect()

    def reset_mm_cache(self) -> None:
        """Reset the multimodal cache of this rank's encoder cache, if any.

        Only multimodal models on the first PP rank create an encoder cache.
        """
        if self.encoder_cache is not None:
            self.encoder_cache.reset_mm_cache()

    def reset_encoder_cache(self) -> None:
        """Reset this rank's encoder cache, if any.

        Only multimodal models on the first PP rank create an encoder cache.
        """
        if self.encoder_cache is not None:
            self.encoder_cache.reset_encoder_cache()

    def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int:
        """Return the number of model input tokens for a scheduled token count.

        Currently the identity, because SP (presumably sequence parallelism)
        is not supported yet.
        """
        num_input_tokens = num_scheduled_tokens
        return num_input_tokens

    @torch.inference_mode()
    def capture_model(self) -> int:
        """Capture CUDA graphs for the model and, if present, the speculator.

        Returns:
            The drop in free GPU memory (bytes) measured across the capture,
            or 0 when capture is skipped (the CUDA graph manager reports
            nothing to capture, or pipeline parallelism is enabled).
        """
        if not self.cudagraph_manager.needs_capture():
            logger.warning(
                "Skipping CUDA graph capture. To turn on CUDA graph capture, "
                "ensure `cudagraph_mode` was not manually set to `NONE`"
            )
            return 0

        # TODO (zhanqiu): support CUDA graph for PP.
        if self.use_pp:
            logger.warning_once(
                "Skipping CUDA graph capture because pipeline parallel is "
                "enabled. Pipeline parallel is currently eager-only.",
            )
            return 0

        start_time = time.perf_counter()
        # Release unreferenced/cached memory first so the free-memory delta
        # below reflects the capture itself.
        gc.collect()
        torch.cuda.empty_cache()
        start_free_gpu_memory = torch.cuda.mem_get_info()[0]

        with self.maybe_setup_dummy_loras(self.lora_config):
            self.cudagraph_manager.capture(
                model=self.model,
                model_state=self.model_state,
                input_buffers=self.input_buffers,
                block_tables=self.block_tables,
                attn_groups=self.attn_groups,
                kv_cache_config=self.kv_cache_config,
                has_lora=self.lora_config is not None,
            )
            if self.speculator is not None:
                self.speculator.capture_model()

        end_time = time.perf_counter()
        end_free_gpu_memory = torch.cuda.mem_get_info()[0]
        elapsed_time = end_time - start_time
        cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
        # This usually takes 5~20 seconds.
        logger.info(
            "Graph capturing finished in %.0f secs, took %.2f GiB",
            elapsed_time,
            cuda_graph_size / (1 << 30),
        )
        return cuda_graph_size

    def warmup_for_prefill(self) -> None:
        """Run a dummy prefill with attention when every backend is FlashInfer.

        The run triggers FlashInfer's JIT compilation ahead of real traffic.
        It is skipped when any backend is not FlashInfer, or when there are
        no attention backends at all.
        """
        backends = self.attn_backends.values()
        # `all()` over an empty iterable is True; require at least one backend
        # so a configuration without attention backends skips this
        # FlashInfer-only warmup.
        if backends and all("FLASHINFER" in b.get_name() for b in backends):
            self._dummy_run(self.max_num_tokens, skip_attn=False)
            torch.cuda.synchronize()

    def finish_requests(self, scheduler_output: SchedulerOutput) -> None:
        """Remove per-request state for finished and preempted requests."""
        finished_req_ids = scheduler_output.finished_req_ids
        preempted_req_ids = scheduler_output.preempted_req_ids
        if preempted_req_ids:
            # Preempted requests are cleaned up like finished ones. `union`
            # builds a new collection rather than mutating the scheduler's.
            finished_req_ids = finished_req_ids.union(preempted_req_ids)
        for req_id in finished_req_ids:
            self.req_states.remove_request(req_id)
            if self.encoder_cache is not None:
                self.encoder_cache.remove_request(req_id)
            self.prompt_logprobs_worker.remove_request(req_id)
            self.lora_state.remove_request(req_id)

    def free_states(self, scheduler_output: SchedulerOutput) -> None:
        """Free encoder-cache entries listed in ``free_encoder_mm_hashes``."""
        if self.encoder_cache is not None:
            for mm_hash in scheduler_output.free_encoder_mm_hashes:
                self.encoder_cache.free_encoder_cache(mm_hash)

    def add_requests(self, scheduler_output: SchedulerOutput) -> None:
        """Register newly scheduled requests with every per-request component.

        Staged writes of ``req_states``, ``sampler`` and ``model_state`` are
        applied once, after all new requests have been added.
        """
        for new_req_data in scheduler_output.scheduled_new_reqs:
            assert new_req_data.prompt_token_ids is not None
            assert new_req_data.prefill_token_ids is not None
            req_id = new_req_data.req_id
            prompt_len = len(new_req_data.prompt_token_ids)
            self.req_states.add_request(
                req_id=req_id,
                prompt_len=prompt_len,
                all_token_ids=new_req_data.prefill_token_ids,
                num_computed_tokens=new_req_data.num_computed_tokens,
            )
            # The index assigned by req_states is reused by the components below.
            req_index = self.req_states.req_id_to_index[req_id]

            if self.encoder_cache is not None:
                self.encoder_cache.add_request(req_id, new_req_data.mm_features)

            self.model_state.add_request(req_index, new_req_data)
            # overwrite=True: a new request starts from its own block IDs
            # (presumably discarding whatever the slot held before).
            self.block_tables.append_block_ids(
                req_index, new_req_data.block_ids, overwrite=True
            )
            self.lora_state.add_request(req_id, req_index, new_req_data.lora_request)

            # NOTE(review): requests without sampling params are presumably
            # pooling requests — confirm.
            if new_req_data.sampling_params is not None:
                self.sampler.add_request(
                    req_index, prompt_len, new_req_data.sampling_params
                )
                self.prompt_logprobs_worker.add_request(
                    req_id, req_index, new_req_data.sampling_params
                )

        if scheduler_output.scheduled_new_reqs:
            self.req_states.apply_staged_writes()
            self.sampler.apply_staged_writes()
            self.model_state.apply_staged_writes()

    def update_requests(self, scheduler_output: SchedulerOutput) -> None:
        """Append newly allocated block IDs to already-scheduled requests."""
        # Add new blocks for the existing requests.
        reqs = scheduler_output.scheduled_cached_reqs
        for req_new_block_ids, req_id in zip(reqs.new_block_ids, reqs.req_ids):
            if req_new_block_ids is not None:
                req_index = self.req_states.req_id_to_index[req_id]
                # overwrite=False: extend the request's existing block list.
                self.block_tables.append_block_ids(
                    req_index, req_new_block_ids, overwrite=False
                )

    def prepare_inputs(
564
        self, scheduler_output: SchedulerOutput, num_tokens_after_padding: int
Woosuk Kwon's avatar
Woosuk Kwon committed
565
566
567
    ) -> InputBatch:
        num_tokens = scheduler_output.total_num_scheduled_tokens
        assert num_tokens > 0
568
569
        num_tokens_per_req = scheduler_output.num_scheduled_tokens
        num_reqs = len(num_tokens_per_req)
Woosuk Kwon's avatar
Woosuk Kwon committed
570
571
572

        # Decode first, then prefill.
        # batch_idx -> req_id
573
574
575
        req_ids = sorted(num_tokens_per_req, key=num_tokens_per_req.get)  # type: ignore[arg-type]
        numtoks_iter = map(num_tokens_per_req.get, req_ids)
        num_scheduled_tokens = np.fromiter(numtoks_iter, dtype=np.int32, count=num_reqs)
Woosuk Kwon's avatar
Woosuk Kwon committed
576

577
578
        idx_mapping_iter = map(self.req_states.req_id_to_index.get, req_ids)
        idx_mapping_np = np.fromiter(idx_mapping_iter, dtype=np.int32, count=num_reqs)
579
        idx_mapping = async_copy_to_gpu(idx_mapping_np, device=self.device)
Woosuk Kwon's avatar
Woosuk Kwon committed
580

581
        # Get the number of draft tokens for each request.
582
583
        draft_tokens = scheduler_output.scheduled_spec_decode_tokens
        if not draft_tokens:
584
585
586
            # No draft token scheduled (common case).
            total_num_draft_tokens = 0
            total_num_logits = num_reqs
587
            cu_num_logits_np = np.arange(num_reqs + 1, dtype=np.int32)
588
589
590
            cu_num_logits = torch.arange(
                num_reqs + 1, device=self.device, dtype=torch.int32
            )
591
            expanded_idx_mapping = idx_mapping
592
593
594
            expanded_local_pos = torch.zeros(
                num_reqs, dtype=torch.int32, device=self.device
            )
595
596
        else:
            num_draft_tokens = np.array(
597
                [len(draft_tokens.get(req_id, ())) for req_id in req_ids],
598
599
600
601
602
                dtype=np.int32,
            )
            total_num_draft_tokens = int(num_draft_tokens.sum())
            total_num_logits = num_reqs + total_num_draft_tokens

603
604
605
606
            num_logits = num_draft_tokens + 1
            cu_num_logits_np = np.empty(num_reqs + 1, dtype=np.int32)
            cu_num_logits_np[0] = 0
            np.cumsum(num_logits, out=cu_num_logits_np[1:])
607
            cu_num_logits = async_copy_to_gpu(cu_num_logits_np, device=self.device)
608

609
            max_expand_len = self.num_speculative_steps + 1
610
            expanded_idx_mapping, expanded_local_pos = expand_idx_mapping(
611
                idx_mapping, total_num_logits, cu_num_logits, max_expand_len
612
613
            )

614
        # Get query_start_loc.
615
616
617
        query_start_loc_np = np.empty(self.max_num_reqs + 1, dtype=np.int32)
        query_start_loc_np[0] = 0
        np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1 : num_reqs + 1])
618
619
        # Pad for full CUDA graph mode.
        # Some attention backends like FA3 require query_start_loc to be non-decreasing.
620
        query_start_loc_np[num_reqs + 1 :] = num_tokens
621
        async_copy_to_gpu(query_start_loc_np, out=self.input_buffers.query_start_loc)
622
623
        query_start_loc_np = query_start_loc_np[: num_reqs + 1]
        query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
624

625
626
627
628
629
630
631
632
633
634
635
        # Get prefill tokens if any.
        if self.req_states.any_prefills(idx_mapping_np):
            prepare_prefill_inputs(
                self.input_buffers.input_ids,
                self.req_states.next_prefill_tokens,
                idx_mapping,
                query_start_loc,
                self.req_states.all_token_ids.gpu,
                self.req_states.prefill_len.gpu,
                self.req_states.num_computed_tokens.gpu,
            )
Woosuk Kwon's avatar
Woosuk Kwon committed
636

637
638
639
        # Prepare positions and seq_lens.
        prepare_pos_seq_lens(
            idx_mapping,
640
641
            query_start_loc,
            self.req_states.num_computed_tokens.gpu,
642
643
644
645
646
            self.input_buffers.positions,
            self.input_buffers.seq_lens,
        )
        seq_lens = self.input_buffers.seq_lens[:num_reqs]

647
        dcp_local_seq_lens = None
648
649
        if self.use_dcp:
            # Prepare dcp local seq_lens.
650
651
            prepare_dcp_local_seq_lens(
                self.input_buffers.dcp_local_seq_lens,
652
                self.input_buffers.seq_lens,
653
                num_reqs,
654
655
656
                self.dcp_size,
                self.dcp_rank,
                self.cp_interleave,
657
            )
658
            dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs]
659

660
        # Some input token ids are directly read from the last sampled tokens
661
662
        # and draft tokens. Also, get the logits indices to sample tokens from.
        logits_indices = combine_sampled_and_draft_tokens(
663
            self.input_buffers.input_ids,
Woosuk Kwon's avatar
Woosuk Kwon committed
664
665
            idx_mapping,
            self.req_states.last_sampled_tokens,
666
            query_start_loc,
667
668
            seq_lens,
            self.req_states.prefill_len.gpu,
669
670
671
            self.req_states.draft_tokens,
            cu_num_logits,
            total_num_logits,
Woosuk Kwon's avatar
Woosuk Kwon committed
672
673
674
675
676
677
678
        )

        return InputBatch(
            req_ids=req_ids,
            num_reqs=num_reqs,
            idx_mapping=idx_mapping,
            idx_mapping_np=idx_mapping_np,
679
            expanded_idx_mapping=expanded_idx_mapping,
680
            expanded_local_pos=expanded_local_pos,
Woosuk Kwon's avatar
Woosuk Kwon committed
681
682
683
            num_scheduled_tokens=num_scheduled_tokens,
            num_tokens=num_tokens,
            num_tokens_after_padding=num_tokens_after_padding,
684
            num_draft_tokens=total_num_draft_tokens,
685
            query_start_loc=query_start_loc,
Woosuk Kwon's avatar
Woosuk Kwon committed
686
            query_start_loc_np=query_start_loc_np,
687
            seq_lens=seq_lens,
688
689
690
            dcp_local_seq_lens=dcp_local_seq_lens,
            input_ids=self.input_buffers.input_ids[:num_tokens_after_padding],
            positions=self.input_buffers.positions[:num_tokens_after_padding],
Woosuk Kwon's avatar
Woosuk Kwon committed
691
            logits_indices=logits_indices,
692
            cu_num_logits=cu_num_logits,
693
            cu_num_logits_np=cu_num_logits_np,
694
            has_structured_output_reqs=scheduler_output.has_structured_output_requests,
Woosuk Kwon's avatar
Woosuk Kwon committed
695
696
        )

697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
    def prepare_attn(
        self, input_batch: InputBatch
    ) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]:
        """Build the attention inputs for a real (scheduled) batch.

        Returns:
            ``(block_tables, slot_mappings)``:
            - block tables gathered via ``input_batch.idx_mapping``, one per KV
              cache group, each ``[num_reqs, max_num_blocks]``;
            - slot mappings computed from the batch's ``query_start_loc`` and
              ``positions``, laid out as ``[num_kv_cache_groups, num_tokens]``.
        """
        table_store = self.block_tables
        idx_mapping = input_batch.idx_mapping
        gathered_tables = table_store.gather_block_tables(idx_mapping)
        slots = table_store.compute_slot_mappings(
            idx_mapping,
            input_batch.query_start_loc,
            input_batch.positions,
        )
        return gathered_tables, slots

    def prepare_dummy_attn(
        self, input_batch: InputBatch
    ) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]:
        """Build placeholder attention inputs for a dummy batch.

        The dummy block tables are sized by ``input_batch.num_reqs`` and the
        dummy slot mappings by ``input_batch.num_tokens``.

        Returns:
            ``(block_tables, slot_mappings)`` from the block-table helper's
            dummy accessors.
        """
        table_store = self.block_tables
        placeholder_tables = table_store.get_dummy_block_tables(input_batch.num_reqs)
        placeholder_slots = table_store.get_dummy_slot_mappings(
            input_batch.num_tokens
        )
        return placeholder_tables, placeholder_slots

Woosuk Kwon's avatar
Woosuk Kwon committed
719
720
721
722
723
    def sample(
        self,
        hidden_states: torch.Tensor,
        input_batch: InputBatch,
        grammar_output: GrammarOutput | None,
    ) -> tuple[SamplerOutput, torch.Tensor, torch.Tensor]:
        """Compute logits at the batch's logits indices and sample from them.

        Args:
            hidden_states: Model output hidden states for the batch.
            input_batch: The batch being processed.
            grammar_output: Grammar bitmasks for structured-output requests, or
                None when no bitmask needs to be applied.

        Returns:
            ``(sampler_output, num_sampled, num_rejected)``; the two counts are
            the ones returned by ``get_num_sampled_and_rejected``.
        """
        logits_indices = input_batch.logits_indices
        # Only the rows at `logits_indices` participate in sampling.
        selected_hidden = hidden_states[logits_indices]
        selected_pos = input_batch.positions[logits_indices]
        selected_ids = input_batch.input_ids[logits_indices]

        logits = self.model.compute_logits(selected_hidden)
        if grammar_output is not None:
            # The grammar bitmask is applied to `logits` in place.
            self.structured_outputs_worker.apply_grammar_bitmask(
                logits,
                input_batch,
                grammar_output.structured_output_request_ids,
                grammar_output.grammar_bitmask,
            )

        # Sample tokens (and compute logprobs when requested).
        sampler_output = self.sampler(
            logits,
            input_batch.expanded_idx_mapping,
            input_batch.idx_mapping_np,
            input_batch.cu_num_logits_np,
            selected_pos,
            selected_ids,
            input_batch.expanded_local_pos,
        )

        if input_batch.num_draft_tokens != 0:
            # Spec decoding: run rejection sampling over the draft tokens and
            # replace the sampled token ids with its output.
            rejection_output, num_sampled = rejection_sample(
                sampler_output.sampled_token_ids,
                selected_ids,
                input_batch.cu_num_logits,
                self.num_speculative_steps,
            )
            sampler_output.sampled_token_ids = rejection_output
        else:
            # Common case without draft tokens: one sampled token per request.
            num_sampled = torch.ones(
                input_batch.num_reqs, dtype=torch.int32, device=self.device
            )

        # Final per-request sampled/rejected counts.
        # For chunked prefills, num_sampled and num_rejected are both 0.
        num_sampled, num_rejected = get_num_sampled_and_rejected(
            num_sampled,
            input_batch.seq_lens,
            input_batch.cu_num_logits,
            input_batch.idx_mapping,
            self.req_states.prefill_len.gpu,
        )
        return sampler_output, num_sampled, num_rejected
Woosuk Kwon's avatar
Woosuk Kwon committed
774
775
776
777

    def postprocess(
        self,
        input_batch: InputBatch,
        sampled_tokens: torch.Tensor,
        num_sampled: torch.Tensor,
        num_rejected: torch.Tensor,
    ) -> None:
        """Fold one step's sampling results back into the request states.

        First hands the sampled tokens and the sampled/rejected counts to
        ``post_update``, along with the per-request computed-token counters,
        last sampled tokens, penalty bin counts and token-id buffers. Then adds
        this step's scheduled tokens to the NumPy computed-prefill counters and
        clamps them to each request's prefill length.
        """
        states = self.req_states
        post_update(
            input_batch.idx_mapping,
            states.num_computed_tokens.gpu,
            states.last_sampled_tokens,
            self.sampler.penalties_state.output_bin_counts,
            sampled_tokens,
            num_sampled,
            num_rejected,
            input_batch.query_start_loc,
            states.all_token_ids.gpu,
            states.total_len.gpu,
        )

        # Advance the computed prefill tokens, never exceeding prefill_len.
        prefill_progress = states.num_computed_prefill_tokens
        prefill_progress[input_batch.idx_mapping_np] += (
            input_batch.num_scheduled_tokens
        )
        np.minimum(prefill_progress, states.prefill_len.np, out=prefill_progress)

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: SchedulerOutput,
        intermediate_tensors: IntermediateTensors | None = None,
        dummy_run: bool = False,
        skip_attn_for_dummy_run: bool = False,
    ) -> ModelRunnerOutput | IntermediateTensors | None:
        """Run the model forward pass for one scheduler step.

        Real step (``dummy_run=False``):
            The request states are first updated from ``scheduler_output``.
            The call returns the kv connector's ``no_forward`` output early when
            no tokens are scheduled.

        Dummy run (DP or memory profiling):
            A dummy InputBatch is built instead. ``skip_attn_for_dummy_run``
            skips building attention inputs/metadata.

        In both cases the forward results are stored in
        ``self.execute_model_state`` for the following ``sample_tokens`` /
        ``pool`` call.

        Args:
            scheduler_output: The scheduler's decisions for this step.
            intermediate_tensors: Inputs received from the previous PP rank;
                only forwarded to the model on non-first PP ranks.
            dummy_run: Run with a dummy batch instead of scheduled requests.
            skip_attn_for_dummy_run: In a dummy run, skip attention metadata.

        Returns:
            One of:
            - the kv connector's ``no_forward`` output when there is nothing
              to run;
            - this rank's IntermediateTensors on non-last PP ranks;
            - None otherwise (sampling happens in ``sample_tokens``).
        """
        if not dummy_run:
            # Update the request states.
            self.finish_requests(scheduler_output)
            self.free_states(scheduler_output)
            self.add_requests(scheduler_output)
            self.update_requests(scheduler_output)
            self.block_tables.apply_staged_writes()
            if scheduler_output.total_num_scheduled_tokens == 0:
                # No need to run the model.
                empty_output = self.kv_connector.no_forward(scheduler_output)
                return empty_output

        # Get local cudagraph mode and size.
        # `default=0`: max() raises on an empty mapping. Non-dummy steps with no
        # tokens already returned above, but a dummy run with no scheduled
        # requests (presumably how a zero-token rank joins the DP sync below)
        # would otherwise crash here before reaching the sync.
        local_cudagraph_mode, local_cudagraph_size = (
            self.cudagraph_manager.get_cudagraph_runtime_mode(
                num_reqs=len(scheduler_output.num_scheduled_tokens),
                num_tokens=scheduler_output.total_num_scheduled_tokens,
                max_query_len=max(
                    scheduler_output.num_scheduled_tokens.values(), default=0
                ),
            )
        )

        # DP sync: num_tokens + cudagraph_size + cudagraph_mode
        num_tokens_after_padding, num_tokens_across_dp, synced_cudagraph_mode = (
            get_cudagraph_and_dp_padding(
                scheduler_output.total_num_scheduled_tokens,
                local_cudagraph_size,
                local_cudagraph_mode.value,
                self.parallel_config.data_parallel_size,
                self.parallel_config.data_parallel_rank,
            )
        )
        cudagraph_runtime_mode = CUDAGraphMode(synced_cudagraph_mode)
        if num_tokens_after_padding == 0:
            # All DP ranks have zero tokens to run.
            empty_output = self.kv_connector.no_forward(scheduler_output)
            return empty_output

        if not dummy_run:
            # Common case.
            # Prepare all the inputs and copy to the input buffers.
            input_batch = self.prepare_inputs(
                scheduler_output, num_tokens_after_padding
            )
            block_tables, slot_mappings = self.prepare_attn(input_batch)

            if self.lora_config:
                # Activate LoRA adapters.
                lora_inputs = self.lora_state.make_lora_inputs(
                    input_batch.req_ids,
                    input_batch.idx_mapping_np,
                    input_batch.num_scheduled_tokens,
                )
                self._set_active_loras(*lora_inputs)
        else:
            # No actual tokens to run. A dummy run for DP or memory profiling.
            num_reqs = min(num_tokens_after_padding, self.max_num_reqs)
            input_batch = InputBatch.make_dummy(
                num_reqs=num_reqs,
                num_tokens=num_tokens_after_padding,
                input_buffers=self.input_buffers,
                device=self.device,
            )
            if not skip_attn_for_dummy_run:
                block_tables, slot_mappings = self.prepare_dummy_attn(input_batch)
            else:
                block_tables = None
                slot_mappings = None
            # FIXME(woosuk): Fix warmup for LoRA.

        attn_metadata = None
        slot_mappings_by_layer = None
        if not (dummy_run and skip_attn_for_dummy_run):
            assert slot_mappings is not None
            slot_mappings_by_layer = build_slot_mappings_by_layer(
                slot_mappings, self.kv_cache_config
            )
            assert block_tables is not None
            attn_metadata = self.model_state.prepare_attn(
                input_batch,
                block_tables,
                slot_mappings,
                self.attn_groups,
                self.kv_cache_config,
            )

        inputs_embeds = None
        if self.supports_mm_inputs and self.is_first_pp_rank and not dummy_run:
            # Run MM encoder (if needed) and get multimodal embeddings.
            # Only first PP rank prepares multimodal embeddings.
            inputs_embeds = self.model_state.get_mm_embeddings(
                scheduler_output.scheduled_encoder_inputs,
                input_batch,
                self.req_states,
            )

        model_inputs = {
            "input_ids": input_batch.input_ids,
            "positions": input_batch.positions,
            "inputs_embeds": inputs_embeds,
            # NOTE: Values returned by `prepare_inputs` will override the default
            # values above.
            **self.model_state.prepare_inputs(input_batch, self.req_states),
        }
        if not self.is_first_pp_rank:
            # Update for non-first PP ranks.
            model_inputs["input_ids"] = None
            model_inputs["inputs_embeds"] = None
            model_inputs["intermediate_tensors"] = intermediate_tensors

        # Run model.
        if cudagraph_runtime_mode == CUDAGraphMode.FULL:
            # Use explicit cudagraph replay for FULL mode.
            # NOTE(woosuk): Here, we don't need to pass the input tensors,
            # because they are already copied to the CUDA graph input buffers.
            self.kv_connector.pre_forward(scheduler_output)
            model_output = self.cudagraph_manager.run_fullgraph(
                input_batch.num_tokens_after_padding
            )
            if self.use_aux_hidden_state_outputs:
                hidden_states, aux_hidden_states = model_output
            else:
                hidden_states = model_output
                aux_hidden_states = None
        else:
            # For piecewise and eager mode, just call model().
            batch_descriptor = BatchDescriptor(
                num_tokens=input_batch.num_tokens_after_padding,
                has_lora=self.lora_config is not None,
            )

            with set_forward_context(
                attn_metadata,
                self.vllm_config,
                num_tokens=input_batch.num_tokens_after_padding,
                cudagraph_runtime_mode=cudagraph_runtime_mode,
                num_tokens_across_dp=num_tokens_across_dp,
                batch_descriptor=batch_descriptor,
                slot_mapping=slot_mappings_by_layer,
            ):
                self.kv_connector.pre_forward(scheduler_output)
                model_output = self.model(**model_inputs)
                if self.use_aux_hidden_state_outputs:
                    hidden_states, aux_hidden_states = model_output
                else:
                    hidden_states = model_output
                    aux_hidden_states = None

        kv_connector_output = self.kv_connector.post_forward(scheduler_output)
        # Stash everything the following sample_tokens()/pool() call needs.
        self.execute_model_state = (
            input_batch,
            model_inputs,
            attn_metadata,
            slot_mappings_by_layer,
            hidden_states,
            aux_hidden_states,
            kv_connector_output,
        )

        if not self.is_last_pp_rank:
            # Non-last PP rank: return IntermediateTensors for sending.
            assert isinstance(hidden_states, IntermediateTensors)
            hidden_states.kv_connector_output = kv_connector_output
            return hidden_states
        # Last rank (or no PP): hidden_states is a tensor for sampling.
        assert isinstance(hidden_states, torch.Tensor)
        return None

    @torch.inference_mode()
    def sample_tokens(
        self, grammar_output: GrammarOutput | None
    ) -> AsyncOutput | ModelRunnerOutput | None:
        """Sample tokens using the state stored by the preceding execute_model.

        Consumes (and clears) ``self.execute_model_state``.

        On non-last PP ranks the sampled tokens are received from the last rank
        via ``pp_receive`` and only the local request state is updated.

        On the last rank this:
        - samples;
        - broadcasts the result to the other PP ranks when PP is enabled;
        - computes prompt logprobs and updates the request state;
        - proposes draft tokens when a speculator is configured.

        Args:
            grammar_output: Grammar bitmasks for structured-output requests, or
                None.

        Returns:
            None when there is no stored state or on non-last PP ranks.
            Otherwise the AsyncOutput under async scheduling, or the result of
            its ``get_output()``.
        """
        if self.execute_model_state is None:
            # The prior execute_model call must have failed.
            return None
        (
            input_batch,
            _,  # model_inputs: not needed for sampling.
            attn_metadata,
            slot_mappings_by_layer,
            hidden_states,
            aux_hidden_states,
            kv_connector_output,
        ) = self.execute_model_state
        self.execute_model_state = None

        if not self.is_last_pp_rank:
            # Non-last PP rank: hidden_states holds the IntermediateTensors this
            # rank produced in execute_model, not final hidden states, so no
            # sampling happens here. Receive the sampled tokens broadcast from
            # the last rank and update local state.
            sampled, num_sampled, num_rejected = pp_receive(
                input_batch.num_reqs, max_sample_len=self.num_speculative_steps + 1
            )
            self.postprocess(input_batch, sampled, num_sampled, num_rejected)
            return None

        # Last rank: sample tokens
        sampler_output, num_sampled, num_rejected = self.sample(
            hidden_states, input_batch, grammar_output
        )

        if self.use_pp:
            # Broadcast to non-last PP ranks (handles spec decode multi-token).
            pp_broadcast(sampler_output.sampled_token_ids, num_sampled, num_rejected)

        prompt_logprobs_dict = self.prompt_logprobs_worker.compute_prompt_logprobs(
            self.model.compute_logits,
            hidden_states,
            input_batch,
            self.req_states.all_token_ids.gpu,
            self.req_states.num_computed_tokens.gpu,
            self.req_states.prompt_len.np,
            self.req_states.prefill_len.np,
            self.req_states.num_computed_prefill_tokens,
        )

        # Prepare the model runner output.
        model_runner_output = ModelRunnerOutput(
            req_ids=input_batch.req_ids,
            # NOTE(woosuk): req_id_to_index is unused in this model runner.
            # Only for compatibility with the existing model runner and scheduler.
            req_id_to_index={req_id: i for i, req_id in enumerate(input_batch.req_ids)},
            sampled_token_ids=None,  # type: ignore
            prompt_logprobs_dict=prompt_logprobs_dict,  # type: ignore[arg-type]
            kv_connector_output=kv_connector_output,
        )
        async_output = AsyncOutput(
            model_runner_output=model_runner_output,
            sampler_output=sampler_output,
            num_sampled_tokens=num_sampled,
            main_stream=self.main_stream,
            copy_stream=self.output_copy_stream,
            copy_event=self.output_copy_event,
        )

        # Postprocess results and update request states.
        # NOTE: This is intentionally done after creating the AsyncOutput,
        # ensuring that `copy_event` is recorded before calling postprocess.
        # This sequencing may slightly reduce latency as async D2H copy does not
        # need to wait for the postprocess to finish.
        self.postprocess(
            input_batch, sampler_output.sampled_token_ids, num_sampled, num_rejected
        )
        if self.speculator is not None:
            draft_tokens = self.speculator.propose(
                input_batch,
                attn_metadata,
                slot_mappings_by_layer,
                hidden_states,
                aux_hidden_states,
                num_sampled,
                num_rejected,
                self.req_states.last_sampled_tokens,
                self.req_states.next_prefill_tokens,
                self.sampler.sampling_states.temperature.gpu,
                self.sampler.sampling_states.seeds.gpu,
            )
            self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
            self.draft_tokens_handler.set_draft_tokens(input_batch, draft_tokens)

        if self.use_async_scheduling:
            return async_output
        return async_output.get_output()
1074
1075
1076

    def take_draft_token_ids(self) -> DraftTokenIds | None:
        """Return whatever the draft-token handler's ``get_draft_tokens`` yields."""
        handler = self.draft_tokens_handler
        return handler.get_draft_tokens()
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131

    @torch.inference_mode()
    def pool(self) -> AsyncPoolingOutput | ModelRunnerOutput | None:
        """Run the pooler on the state stored by the preceding execute_model.

        Consumes (and clears) ``self.execute_model_state``.

        Returns:
            None when there is no stored state, and on non-last PP ranks (which
            only update their token bookkeeping). Otherwise the
            AsyncPoolingOutput under async scheduling, or the result of its
            ``get_output()``.
        """
        state = self.execute_model_state
        if state is None:
            # The prior execute_model call must have failed.
            return None

        input_batch, _, _, _, hidden_states, _, kv_connector_output = state
        self.execute_model_state = None

        if not self.is_last_pp_rank:
            self.postprocess_pool(input_batch)
            return None

        runner = self.pooling_runner
        assert runner is not None
        pooler_output, is_valid = runner.pool(
            hidden_states, input_batch, self.req_states
        )
        self.postprocess_pool(input_batch)

        # Build the model runner output.
        req_ids = input_batch.req_ids
        model_runner_output = ModelRunnerOutput(
            req_ids=req_ids,
            req_id_to_index={rid: pos for pos, rid in enumerate(req_ids)},
            kv_connector_output=kv_connector_output,
        )
        async_output = AsyncPoolingOutput(
            model_runner_output=model_runner_output,
            pooler_output=pooler_output,
            is_valid=is_valid,
            main_stream=self.main_stream,
            copy_stream=self.output_copy_stream,
            copy_event=self.output_copy_event,
        )
        return async_output if self.use_async_scheduling else async_output.get_output()

    def postprocess_pool(self, input_batch: InputBatch) -> None:
        """Advance the token bookkeeping after a pooling step.

        Calls ``post_update_pool`` with the batch's index mapping, the
        ``req_states.num_computed_tokens.gpu`` counters and
        ``query_start_loc``. Then adds the scheduled tokens to the NumPy
        computed-prefill counters, clamped to each request's prefill length.
        """
        states = self.req_states
        post_update_pool(
            input_batch.idx_mapping,
            states.num_computed_tokens.gpu,
            input_batch.query_start_loc,
        )

        # Advance the computed prefill tokens, never exceeding prefill_len.
        prefill_progress = states.num_computed_prefill_tokens
        prefill_progress[input_batch.idx_mapping_np] += (
            input_batch.num_scheduled_tokens
        )
        np.minimum(prefill_progress, states.prefill_len.np, out=prefill_progress)