model_runner.py 40.1 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
import time
from copy import deepcopy
from typing import Any

import numpy as np
import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model_loader
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import DeviceMemoryProfiler
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import (
    EMPTY_MODEL_RUNNER_OUTPUT,
    LogprobsTensors,
    ModelRunnerOutput,
)
from vllm.v1.sample.sampler import SamplerOutput
from vllm.v1.worker.gpu.async_utils import AsyncOutput, async_barrier
from vllm.v1.worker.gpu.attn_utils import (
    build_attn_metadata,
    get_kv_cache_spec,
    init_attn_backend,
    init_kv_cache,
)
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
from vllm.v1.worker.gpu.dp_utils import get_batch_metadata_across_dp
from vllm.v1.worker.gpu.input_batch import (
    InputBatch,
    InputBuffers,
42
    combine_sampled_and_draft_tokens,
43
    post_update,
44
45
    prepare_pos_seq_lens,
    prepare_prefill_inputs,
Woosuk Kwon's avatar
Woosuk Kwon committed
46
47
)
from vllm.v1.worker.gpu.sampler import Sampler, compute_prompt_logprobs
48
from vllm.v1.worker.gpu.spec_decode import init_speculator
49
from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample
Woosuk Kwon's avatar
Woosuk Kwon committed
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata
from vllm.v1.worker.gpu.structured_outputs import apply_grammar_bitmask
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

logger = init_logger(__name__)


class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Cache configuration and allocate the runner's persistent state.

        Args:
            vllm_config: Engine configuration; its sub-configs are stored as
                attributes for convenient access.
            device: Device used for the output copy stream, request state,
                input buffers, the speculator and the CUDA graph manager.
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.compilation_config = vllm_config.compilation_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.speculative_config = vllm_config.speculative_config
        self.observability_config = vllm_config.observability_config

        self.device = device
        self.pin_memory = is_pin_memory_available()
        self.dtype = self.model_config.dtype
        # The KV cache uses the model dtype unless an explicit cache dtype
        # string is configured.
        self.kv_cache_dtype = self.dtype
        if self.cache_config.cache_dtype != "auto":
            # Quantized KV cache.
            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                self.cache_config.cache_dtype
            ]
        self.is_pooling_model = False

        self.vocab_size = self.model_config.get_vocab_size()
        self.max_model_len = self.model_config.max_model_len
        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
        self.max_num_reqs = self.scheduler_config.max_num_seqs
        self.hidden_size = self.model_config.get_hidden_size()

        self.dp_size = self.parallel_config.data_parallel_size
        self.dp_rank = self.parallel_config.data_parallel_rank

        self.use_async_scheduling = self.scheduler_config.async_scheduling
        self.output_copy_stream = torch.cuda.Stream(self.device)
        self.output_copy_event = torch.cuda.Event()
        # These events are only created with async scheduling; None otherwise.
        if self.use_async_scheduling:
            self.input_prep_event = torch.cuda.Event()
            self.structured_outputs_event = torch.cuda.Event()
            self.spec_decode_event = torch.cuda.Event()
        else:
            self.input_prep_event = None
            self.structured_outputs_event = None
            self.spec_decode_event = None

        # Speculative decoding is enabled iff a speculative config is given.
        if self.speculative_config is not None:
            self.do_spec_decode = True
            self.num_speculative_steps = self.speculative_config.num_speculative_tokens
            self.speculator = init_speculator(self.vllm_config, self.device)
        else:
            self.do_spec_decode = False
            self.num_speculative_steps = 0
            self.speculator = None

        self.req_states = RequestState(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_batched_tokens=self.max_num_tokens,
            num_speculative_steps=self.num_speculative_steps,
            vocab_size=self.vocab_size,
            device=self.device,
            pin_memory=self.pin_memory,
        )
        self.input_buffers = InputBuffers(
            max_num_reqs=self.max_num_reqs,
            max_num_tokens=self.max_num_tokens,
            hidden_size=self.hidden_size,
            vocab_size=self.vocab_size,
            dtype=self.dtype,
            device=self.device,
            pin_memory=self.pin_memory,
        )
        self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)

        # CUDA graphs.
        self.cudagraph_manager = CudaGraphManager(
            vllm_config=self.vllm_config,
            device=self.device,
        )

    def get_supported_tasks(self) -> tuple[str, ...]:
        """Return the names of the tasks this runner can serve.

        Only ``"generate"`` is supported by this runner.
        """
        return ("generate",)

    def load_model(self, *args, **kwargs) -> None:
        """Load the model, wrap it for LoRA if configured, and load the speculator.

        Extra positional/keyword arguments are accepted but not used. The device
        memory consumed while loading (as measured by ``DeviceMemoryProfiler``)
        is stored in ``self.model_memory_usage`` and logged with the load time.
        """
        time_before_load = time.perf_counter()
        with DeviceMemoryProfiler() as m:
            model_loader = get_model_loader(self.vllm_config.load_config)
            logger.info("Loading model from scratch...")

            self.model = model_loader.load_model(
                vllm_config=self.vllm_config,
                model_config=self.vllm_config.model_config,
            )
            if self.lora_config:
                self.model = self.load_lora_model(
                    self.model,
                    self.vllm_config,
                    self.device,
                )
            if self.do_spec_decode:
                # The speculator receives the (possibly LoRA-wrapped) target
                # model; loaded inside the profiler so its memory is counted.
                self.speculator.load_model(self.model)
        time_after_load = time.perf_counter()

        self.model_memory_usage = m.consumed_memory
        logger.info(
            "Model loading took %.4f GiB and %.6f seconds",
            m.consumed_memory / GiB_bytes,
            time_after_load - time_before_load,
        )

    def get_model(self) -> nn.Module:
        """Return the module stored by ``load_model``."""
        loaded_model = self.model
        return loaded_model

    def get_kv_cache_spec(self):
        """Return the KV cache spec computed from this runner's engine config."""
        # Delegates to the module-level helper of the same name.
        spec = get_kv_cache_spec(self.vllm_config)
        return spec

    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """Create block tables, attention backends and KV cache tensors.

        Args:
            kv_cache_config: KV cache layout. A deep copy is stored on
                ``self.kv_cache_config``, so the caller's object is not shared.

        Raises:
            NotImplementedError: If any attention backend is not FLASH_ATTN.
        """
        kv_cache_config = deepcopy(kv_cache_config)
        self.kv_cache_config = kv_cache_config
        # One block size per KV cache group.
        block_sizes = [
            kv_cache_group.kv_cache_spec.block_size
            for kv_cache_group in kv_cache_config.kv_cache_groups
        ]

        self.block_tables = BlockTables(
            block_sizes=block_sizes,
            max_num_reqs=self.max_num_reqs,
            max_num_batched_tokens=self.max_num_tokens,
            max_model_len=self.max_model_len,
            device=self.device,
            pin_memory=self.pin_memory,
        )

        self.attn_backends, self.attn_metadata_builders = init_attn_backend(
            self.kv_cache_config,
            self.vllm_config,
            self.device,
        )
        # TODO(woosuk): Support other backends.
        if not all(b.get_name() == "FLASH_ATTN" for b in self.attn_backends.values()):
            raise NotImplementedError("Only FLASH_ATTN backend is supported currently.")

        # Starts empty; presumably populated in place by init_kv_cache
        # (TODO confirm against init_kv_cache).
        self.kv_caches: list[torch.Tensor] = []
        init_kv_cache(
            self.kv_caches,
            self.compilation_config.static_forward_context,
            self.kv_cache_config,
            self.attn_backends,
            self.device,
        )
        # Attention groups are not supported.
        self.attn_groups = []  # type: ignore

    def prepare_dummy_attn_metadata(self, input_batch: InputBatch) -> None:
        """Build attention metadata for a dummy batch and attach it in place.

        Uses the block tables' dummy block tables and slot mappings, and treats
        every request as having zero computed tokens. The result is stored on
        ``input_batch.attn_metadata``.
        """
        block_tables = self.block_tables.get_dummy_block_tables(input_batch.num_reqs)
        slot_mappings = self.block_tables.get_dummy_slot_mappings(
            input_batch.num_tokens
        )
        # NOTE: as in prepare_inputs, a device tensor is passed as
        # ``num_computed_tokens_cpu``.
        num_computed_tokens = torch.zeros(
            input_batch.num_reqs, dtype=torch.int32, device=self.device
        )
        attn_metadata = build_attn_metadata(
            attn_metadata_builders=self.attn_metadata_builders,
            num_reqs=input_batch.num_reqs,
            num_tokens=input_batch.num_tokens,
            query_start_loc=self.input_buffers.query_start_loc,
            seq_lens=self.input_buffers.seq_lens,
            seq_lens_np=input_batch.seq_lens_np,
            num_computed_tokens_cpu=num_computed_tokens,
            block_tables=block_tables,
            slot_mappings=slot_mappings,
            kv_cache_config=self.kv_cache_config,
        )
        input_batch.attn_metadata = attn_metadata

    @torch.inference_mode()
    def _dummy_run(
        self,
        num_tokens: int,
        *args,
        skip_attn: bool = True,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the model once on a synthetic batch of ``num_tokens`` tokens.

        Attention metadata is built only when ``skip_attn`` is False; any other
        extra arguments are ignored.

        Returns:
            ``(hidden_states, sample_hidden_states)``, where the second tensor
            is ``hidden_states`` gathered at the batch's logits indices.
        """
        dummy_batch = InputBatch.make_dummy(
            num_reqs=min(num_tokens, self.max_num_reqs),
            num_tokens=num_tokens,
            input_buffers=self.input_buffers,
            device=self.device,
        )
        if not skip_attn:
            self.prepare_dummy_attn_metadata(dummy_batch)

        # With data parallelism, every rank is reported as running the same
        # dummy token count.
        num_tokens_across_dp: torch.Tensor | None = (
            None
            if self.dp_size == 1
            else torch.full(
                (self.dp_size,), num_tokens, dtype=torch.int32, device="cpu"
            )
        )
        # One sampled token per dummy request.
        sampled_per_req = np.ones(dummy_batch.num_reqs, dtype=np.int32)
        with (
            self.maybe_dummy_run_with_lora(
                self.lora_config,
                dummy_batch.num_scheduled_tokens,
                sampled_per_req,
            ),
            set_forward_context(
                dummy_batch.attn_metadata,
                self.vllm_config,
                num_tokens=num_tokens,
                num_tokens_across_dp=num_tokens_across_dp,
            ),
        ):
            out_hidden = self.model(
                input_ids=dummy_batch.input_ids,
                positions=dummy_batch.positions,
            )
            out_sample_hidden = out_hidden[dummy_batch.logits_indices]
        return out_hidden, out_sample_hidden

    @torch.inference_mode()
    def _dummy_sampler_run(
        self,
        hidden_states: torch.Tensor,
    ) -> None:
        """Compute logits and run the sampler once with dummy sampling metadata.

        The first dimension of ``hidden_states`` is used as the number of
        requests, i.e. it is expected to hold one row per sampled position.
        """
        num_reqs = hidden_states.shape[0]
        sampling_metadata = SamplingMetadata.make_dummy(
            num_reqs=num_reqs,
            device=self.device,
        )
        logits = self.model.compute_logits(hidden_states)
        self.sampler(logits, sampling_metadata)

    @torch.inference_mode()
    def _dummy_speculator_run(
        self,
        hidden_states: torch.Tensor,
        aux_hidden_states: list[torch.Tensor] | None,
    ) -> None:
        """Run ``propose_draft`` once on a synthetic batch.

        The dummy batch has ``hidden_states.shape[0]`` tokens spread over at
        most ``max_num_reqs`` requests, and each request reports exactly one
        sampled token.
        """
        num_tokens = hidden_states.shape[0]
        num_reqs = min(num_tokens, self.max_num_reqs)
        input_batch = InputBatch.make_dummy(
            num_reqs=num_reqs,
            num_tokens=num_tokens,
            input_buffers=self.input_buffers,
            device=self.device,
        )
        sampling_metadata = SamplingMetadata.make_dummy(
            num_reqs=num_reqs,
            device=self.device,
        )
        num_sampled = torch.ones(num_reqs, dtype=torch.int32, device=self.device)
        self.propose_draft(
            input_batch=input_batch,
            sampling_metadata=sampling_metadata,
            last_hidden_states=hidden_states,
            aux_hidden_states=aux_hidden_states,
            num_sampled=num_sampled,
        )

    @torch.inference_mode()
    def profile_run(self) -> None:
        """Exercise model, sampler and (if enabled) speculator at max batch size.

        The dummy forward pass uses ``max_num_tokens`` tokens and skips
        attention. The device is synchronized and the intermediate tensors are
        released before returning.
        """
        hidden_states, sample_hidden_states = self._dummy_run(
            self.max_num_tokens,
            skip_attn=True,
        )
        self._dummy_sampler_run(sample_hidden_states)
        if self.do_spec_decode:
            # No auxiliary hidden states are produced by the dummy run.
            self._dummy_speculator_run(hidden_states, None)
        torch.cuda.synchronize()
        del hidden_states, sample_hidden_states
        gc.collect()

    def reset_mm_cache(self) -> None:
        """Reset the multimodal cache; this runner implements it as a no-op."""
        return None

    def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int:
        """Return the number of input tokens for a step.

        SP is not supported yet, so this is simply ``num_scheduled_tokens``.
        """
        num_input_tokens = num_scheduled_tokens
        return num_input_tokens

    @torch.inference_mode()
    def capture_model(self) -> int:
        """Capture CUDA graphs through the CUDA graph manager.

        Returns:
            Bytes of device memory used by the capture, measured as the drop in
            free device memory across it; 0 when the manager reports that no
            capture is needed.
        """
        if not self.cudagraph_manager.needs_capture():
            logger.warning(
                "Skipping CUDA graph capture. To turn on CUDA graph capture, "
                "ensure `cudagraph_mode` was not manually set to `NONE`"
            )
            return 0

        start_time = time.perf_counter()
        # Release unreferenced Python objects and cached allocator blocks
        # before taking the free-memory baseline.
        gc.collect()
        torch.cuda.empty_cache()
        start_free_gpu_memory = torch.cuda.mem_get_info()[0]

        with self.maybe_setup_dummy_loras(self.lora_config):
            self.cudagraph_manager.capture(
                model=self.model,
                input_buffers=self.input_buffers,
                block_tables=self.block_tables,
                attn_metadata_builders=self.attn_metadata_builders,
                kv_cache_config=self.kv_cache_config,
            )

        end_time = time.perf_counter()
        end_free_gpu_memory = torch.cuda.mem_get_info()[0]
        elapsed_time = end_time - start_time
        cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
        # This usually takes 5~20 seconds.
        logger.info(
            "Graph capturing finished in %.0f secs, took %.2f GiB",
            elapsed_time,
            cuda_graph_size / GiB_bytes,
        )
        return cuda_graph_size

    def warmup_for_prefill(self) -> None:
        """Run one dummy prefill (attention enabled) if all backends are FlashInfer.

        For FlashInfer, a dummy prefill run is meant to trigger JIT compilation.
        NOTE(review): initialize_kv_cache currently rejects every backend other
        than FLASH_ATTN, so this only fires when ``attn_backends`` is empty
        (``all`` of an empty iterable is True).
        """
        every_backend_is_flashinfer = all(
            "FLASHINFER" in backend.get_name()
            for backend in self.attn_backends.values()
        )
        if not every_backend_is_flashinfer:
            return
        self._dummy_run(self.max_num_tokens, skip_attn=False)
        torch.cuda.synchronize()

    def update_states(self, scheduler_output: SchedulerOutput) -> None:
        """Sync request state and block tables with ``scheduler_output``.

        Removes preempted and finished requests, registers newly scheduled
        requests in ``req_states``, and forwards newly allocated block ids of
        new and cached requests to ``block_tables.append_block_ids``. New
        requests are recorded with ``overwrite=True``, cached ones with
        ``overwrite=False``.
        """
        if scheduler_output.preempted_req_ids is not None:
            for req_id in scheduler_output.preempted_req_ids:
                self.req_states.remove_request(req_id)
        for req_id in scheduler_output.finished_req_ids:
            self.req_states.remove_request(req_id)

        # TODO(woosuk): Change SchedulerOutput.
        # Flattened per-group updates: for KV cache group g, new_block_ids[g]
        # concatenates the new block ids of every entry in req_indices, and
        # cu_num_new_blocks[g] is the running prefix sum (starting at 0) of
        # how many ids each entry contributed.
        num_groups = self.block_tables.num_kv_cache_groups
        req_indices: list[int] = []
        cu_num_new_blocks: tuple[list[int], ...] = tuple(
            [0] for _ in range(num_groups)
        )
        new_block_ids: tuple[list[int], ...] = tuple([] for _ in range(num_groups))
        overwrite: list[bool] = []

        def record_new_blocks(
            req_index: int, per_group_block_ids, overwrite_flag: bool
        ) -> None:
            # Append one request's new blocks (one sequence per KV cache
            # group) to the flattened update lists above.
            req_indices.append(req_index)
            for group_id, block_ids in enumerate(per_group_block_ids):
                x = cu_num_new_blocks[group_id][-1]
                cu_num_new_blocks[group_id].append(x + len(block_ids))
                new_block_ids[group_id].extend(block_ids)
            overwrite.append(overwrite_flag)

        # Add new requests.
        for new_req_data in scheduler_output.scheduled_new_reqs:
            assert new_req_data.prompt_token_ids is not None
            assert new_req_data.prefill_token_ids is not None
            assert new_req_data.sampling_params is not None
            req_id = new_req_data.req_id
            self.req_states.add_request(
                req_id=req_id,
                prompt_len=len(new_req_data.prompt_token_ids),
                prefill_token_ids=new_req_data.prefill_token_ids,
                num_computed_tokens=new_req_data.num_computed_tokens,
                sampling_params=new_req_data.sampling_params,
                lora_request=new_req_data.lora_request,
            )
            record_new_blocks(
                self.req_states.req_id_to_index[req_id],
                new_req_data.block_ids,
                True,
            )

        # Update the GPU tensors for request states.
        if scheduler_output.scheduled_new_reqs:
            self.req_states.prefill_len.copy_to_gpu()

        # Add new blocks for the existing requests.
        cached_reqs = scheduler_output.scheduled_cached_reqs
        for i, req_id in enumerate(cached_reqs.req_ids):
            # Looked up unconditionally so an unknown req_id still raises.
            req_index = self.req_states.req_id_to_index[req_id]
            req_new_block_ids = cached_reqs.new_block_ids[i]
            if req_new_block_ids is not None:
                record_new_blocks(req_index, req_new_block_ids, False)

        if req_indices:
            self.block_tables.append_block_ids(
                req_indices=req_indices,
                cu_num_new_blocks=cu_num_new_blocks,
                new_block_ids=new_block_ids,
                overwrite=overwrite,
            )

    def prepare_inputs(
        self,
        scheduler_output: SchedulerOutput,
        num_tokens_after_padding: int,
    ) -> InputBatch:
        """Build the ``InputBatch`` for one scheduler step.

        Requests are ordered by ascending number of scheduled tokens. The method
        then gathers block tables and fills ``query_start_loc``, input ids,
        positions and seq_lens in the persistent input buffers. It also computes
        the logits indices and slot mappings, and builds attention metadata.

        Args:
            scheduler_output: Scheduler decisions for this step; must schedule
                at least one token (asserted).
            num_tokens_after_padding: Length to which ``input_ids`` and
                ``positions`` are sliced; presumably the (CUDA-graph) padded
                token count — TODO confirm against callers.

        Returns:
            The assembled ``InputBatch``.
        """
        num_tokens = scheduler_output.total_num_scheduled_tokens
        assert num_tokens > 0
        num_reqs = len(scheduler_output.num_scheduled_tokens)

        # Decode first, then prefill (ascending number of scheduled tokens).
        # batch_idx -> req_id
        req_ids = sorted(
            scheduler_output.num_scheduled_tokens.keys(),
            key=lambda k: scheduler_output.num_scheduled_tokens[k],
        )
        num_scheduled_tokens = np.array(
            [scheduler_output.num_scheduled_tokens[i] for i in req_ids], dtype=np.int32
        )

        # batch_idx -> row index in req_states, on both CPU and GPU.
        idx_mapping_list = [
            self.req_states.req_id_to_index[req_id] for req_id in req_ids
        ]
        idx_mapping = self.input_buffers.idx_mapping
        idx_mapping.np[:num_reqs] = idx_mapping_list
        idx_mapping_np = idx_mapping.np[:num_reqs]
        idx_mapping = idx_mapping.copy_to_gpu(num_reqs)

        # Get the number of draft tokens for each request.
        if not scheduler_output.scheduled_spec_decode_tokens:
            # No draft token scheduled (common case).
            total_num_draft_tokens = 0
            total_num_logits = num_reqs
            cu_num_logits = torch.arange(
                num_reqs + 1, device=self.device, dtype=torch.int32
            )
        else:
            draft_tokens = scheduler_output.scheduled_spec_decode_tokens
            num_draft_tokens = np.array(
                [
                    len(draft_tokens[req_id]) if req_id in draft_tokens else 0
                    for req_id in req_ids
                ],
                dtype=np.int32,
            )
            total_num_draft_tokens = int(num_draft_tokens.sum())
            total_num_logits = num_reqs + total_num_draft_tokens

            # Each request yields (num draft tokens + 1) logits.
            np.cumsum(
                num_draft_tokens + 1,
                out=self.input_buffers.cu_num_logits.np[1 : num_reqs + 1],
            )
            cu_num_logits = self.input_buffers.cu_num_logits.copy_to_gpu(num_reqs + 1)

        # Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
        block_tables = self.block_tables.gather_block_tables(idx_mapping)

        # Get query_start_loc.
        np.cumsum(
            num_scheduled_tokens,
            out=self.input_buffers.query_start_loc.np[1 : num_reqs + 1],
        )
        # Pad for full CUDA graph mode.
        # Some attention backends like FA3 require query_start_loc to be non-decreasing.
        self.input_buffers.query_start_loc.np[num_reqs + 1 :] = num_tokens
        self.input_buffers.query_start_loc.copy_to_gpu()
        query_start_loc_gpu = self.input_buffers.query_start_loc.gpu[: num_reqs + 1]
        query_start_loc_np = self.input_buffers.query_start_loc.np[: num_reqs + 1]

        # Copy prefill tokens from CPU to GPU.
        prepare_prefill_inputs(
            idx_mapping_np,
            num_scheduled_tokens,
            query_start_loc_np,
            self.req_states.prefill_token_ids,
            self.req_states.num_computed_prefill_tokens,
            self.input_buffers.input_ids.np,
        )
        self.input_buffers.input_ids.copy_to_gpu(num_tokens)

        # Prepare positions and seq_lens.
        prepare_pos_seq_lens(
            idx_mapping,
            query_start_loc_gpu,
            self.req_states.num_computed_tokens,
            self.input_buffers.positions,
            self.input_buffers.seq_lens,
        )
        seq_lens = self.input_buffers.seq_lens[:num_reqs]

        # Some input token ids are directly read from the last sampled tokens
        # and draft tokens. Also, get the logits indices to sample tokens from.
        logits_indices = combine_sampled_and_draft_tokens(
            self.input_buffers.input_ids.gpu,
            idx_mapping,
            self.req_states.last_sampled_tokens,
            query_start_loc_gpu,
            seq_lens,
            self.req_states.prefill_len.gpu,
            self.req_states.draft_tokens,
            cu_num_logits,
            total_num_logits,
        )

        # Compute slot mappings: [num_kv_cache_groups, num_tokens]
        slot_mappings = self.block_tables.compute_slot_mappings(
            query_start_loc_gpu, self.input_buffers.positions[:num_tokens]
        )

        # Get num_computed_tokens.
        # HACK(woosuk): Here, we use num_computed_tokens on GPU instead of
        # num_computed_tokens_cpu. This works for most cases.
        num_computed_tokens = self.req_states.num_computed_tokens[idx_mapping]
        # HACK(woosuk): Only GPU has the exact seq_lens because at this point
        # CPU does not know how many draft tokens are accepted/rejected in the
        # previous step. Therefore, we use max_model_len to be safe.
        # NOTE(woosuk): This only works for FA3 backend.
        seq_lens_np = np.full(num_reqs, self.max_model_len, dtype=np.int32)

        # Layer name -> attention metadata.
        attn_metadata = build_attn_metadata(
            attn_metadata_builders=self.attn_metadata_builders,
            num_reqs=num_reqs,
            num_tokens=num_tokens,
            query_start_loc=self.input_buffers.query_start_loc,
            seq_lens=self.input_buffers.seq_lens,
            seq_lens_np=seq_lens_np,
            num_computed_tokens_cpu=num_computed_tokens,
            block_tables=block_tables,
            slot_mappings=slot_mappings,
            kv_cache_config=self.kv_cache_config,
        )

        input_ids = self.input_buffers.input_ids.gpu[:num_tokens_after_padding]
        positions = self.input_buffers.positions[:num_tokens_after_padding]
        return InputBatch(
            req_ids=req_ids,
            num_reqs=num_reqs,
            idx_mapping=idx_mapping,
            idx_mapping_np=idx_mapping_np,
            num_scheduled_tokens=num_scheduled_tokens,
            num_tokens=num_tokens,
            num_tokens_after_padding=num_tokens_after_padding,
            num_draft_tokens=total_num_draft_tokens,
            query_start_loc=query_start_loc_gpu,
            query_start_loc_np=query_start_loc_np,
            seq_lens=seq_lens,
            seq_lens_np=seq_lens_np,
            input_ids=input_ids,
            positions=positions,
            attn_metadata=attn_metadata,
            logits_indices=logits_indices,
            cu_num_logits=cu_num_logits,
        )

    def sample(
        self,
        hidden_states: torch.Tensor,
        input_batch: InputBatch,
        sampling_metadata: SamplingMetadata,
        grammar_output: GrammarOutput | None,
    ) -> tuple[SamplerOutput, torch.Tensor]:
        """Compute logits at the batch's logits indices and sample tokens.

        When ``grammar_output`` is given, its bitmask is applied to the logits
        in place. This path does not support draft tokens yet (asserted). When
        draft tokens are present, the sampled tokens are replaced by the result
        of rejection sampling.

        Returns:
            ``(sampler_output, num_sampled)``. ``num_sampled`` holds the number
            of sampled tokens per request and is 0 for requests that are still
            chunked-prefilling (``seq_lens < prefill_len``).
        """
        sample_hidden_states = hidden_states[input_batch.logits_indices]
        logits = self.model.compute_logits(sample_hidden_states)
        if grammar_output is not None:
            # Apply grammar bitmask to the logits in-place.
            # TODO(woosuk): Make compatible with spec decoding.
            assert input_batch.num_draft_tokens == 0
            with async_barrier(self.structured_outputs_event):
                apply_grammar_bitmask(
                    logits,
                    input_batch.req_ids,
                    grammar_output.structured_output_request_ids,
                    grammar_output.grammar_bitmask,
                    self.input_buffers,
                )

        # Sample tokens and compute logprobs (if needed).
        sampler_output = self.sampler(logits, sampling_metadata)

        # Get the number of sampled tokens.
        prefill_len = self.req_states.prefill_len.gpu[input_batch.idx_mapping]
        is_chunked_prefilling = input_batch.seq_lens < prefill_len
        if input_batch.num_draft_tokens == 0:
            # No draft tokens (common case).
            # 0 if chunked-prefilling, 1 if not.
            num_sampled = (~is_chunked_prefilling).int()
        else:
            # Draft tokens for spec decoding.
            input_ids = input_batch.input_ids[input_batch.logits_indices]
            sampled_tokens, num_sampled = rejection_sample(
                sampler_output.sampled_token_ids,
                input_ids,
                input_batch.cu_num_logits,
                self.num_speculative_steps,
            )
            # Zero out the count for requests still chunked-prefilling.
            num_sampled *= ~is_chunked_prefilling
            sampler_output.sampled_token_ids = sampled_tokens
            # TODO(woosuk): Support logprobs with spec decoding.
        return sampler_output, num_sampled

    def compute_prompt_logprobs(
        self,
        hidden_states: torch.Tensor,
        input_batch: InputBatch,
    ) -> dict[str, LogprobsTensors]:
        idx_mapping_np = input_batch.idx_mapping_np
        needs_prompt_logprobs = self.req_states.needs_prompt_logprobs[idx_mapping_np]
        if not np.any(needs_prompt_logprobs):
            # No request asks for prompt logprobs.
            return {}

        prompt_lens = self.req_states.prompt_len[idx_mapping_np]
        # NOTE(woosuk): -1 because the last prompt token's hidden state is not
        # needed for prompt logprobs.
663
664
        computed_prefill = self.req_states.num_computed_prefill_tokens[idx_mapping_np]
        includes_prompt = computed_prefill < prompt_lens - 1
Woosuk Kwon's avatar
Woosuk Kwon committed
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
        # NOTE(woosuk): If the request was resumed after preemption, its prompt
        # logprobs must have been computed before preemption. Skip.
        resumed_after_prompt = (
            prompt_lens < self.req_states.prefill_len.np[idx_mapping_np]
        )
        needs_prompt_logprobs &= includes_prompt & ~resumed_after_prompt
        if not np.any(needs_prompt_logprobs):
            return {}

        # Just to be safe, clone the input ids.
        n = input_batch.num_tokens
        # Shift the input ids by one.
        token_ids = torch.empty_like(input_batch.input_ids[:n])
        token_ids[: n - 1] = input_batch.input_ids[1:n]
        # To avoid out-of-bound access, set the last token id to 0.
        token_ids[n - 1] = 0

        # Handle chunked prompts.
683
684
        pos_after_step = computed_prefill + input_batch.num_scheduled_tokens
        is_prompt_chunked = pos_after_step < prompt_lens
Woosuk Kwon's avatar
Woosuk Kwon committed
685
686
687
688
689
690
691
692
693
        prefill_token_ids = self.req_states.prefill_token_ids
        query_start_loc = self.input_buffers.query_start_loc.np
        for i, req_id in enumerate(input_batch.req_ids):
            if not needs_prompt_logprobs[i]:
                continue
            if not is_prompt_chunked[i]:
                continue
            # The prompt is chunked. Get the next prompt token.
            req_idx = input_batch.idx_mapping_np[i]
694
            next_prompt_token = int(prefill_token_ids[req_idx, pos_after_step[i]])
Woosuk Kwon's avatar
Woosuk Kwon committed
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
            idx = int(query_start_loc[i + 1] - 1)
            # Set the next prompt token.
            # NOTE(woosuk): This triggers a GPU operation.
            token_ids[idx] = next_prompt_token

        # NOTE(woosuk): We mask out logprobs for negative tokens.
        prompt_logprobs, prompt_ranks = compute_prompt_logprobs(
            token_ids,
            hidden_states[:n],
            self.model.compute_logits,
        )

        prompt_token_ids = token_ids.unsqueeze(-1)
        prompt_logprobs_dict: dict[str, LogprobsTensors] = {}
        for i, req_id in enumerate(input_batch.req_ids):
            if not needs_prompt_logprobs[i]:
                continue

            start_idx = query_start_loc[i]
            end_idx = query_start_loc[i + 1]
            assert start_idx < end_idx, (
                f"start_idx ({start_idx}) >= end_idx ({end_idx})"
            )
            logprobs = LogprobsTensors(
                logprob_token_ids=prompt_token_ids[start_idx:end_idx],
                logprobs=prompt_logprobs[start_idx:end_idx],
                selected_token_ranks=prompt_ranks[start_idx:end_idx],
            )

            req_extra_data = self.req_states.extra_data[req_id]
            prompt_logprobs_list = req_extra_data.in_progress_prompt_logprobs
            if is_prompt_chunked[i]:
                # Prompt is chunked. Do not return the logprobs yet.
                prompt_logprobs_list.append(logprobs)
                continue

            if prompt_logprobs_list:
                # Merge the in-progress logprobs.
                prompt_logprobs_list.append(logprobs)
                logprobs = LogprobsTensors(
                    logprob_token_ids=torch.cat(
                        [x.logprob_token_ids for x in prompt_logprobs_list]
                    ),
                    logprobs=torch.cat([x.logprobs for x in prompt_logprobs_list]),
                    selected_token_ranks=torch.cat(
                        [x.selected_token_ranks for x in prompt_logprobs_list]
                    ),
                )
                prompt_logprobs_list.clear()

            prompt_logprobs_dict[req_id] = logprobs
        return prompt_logprobs_dict

    def postprocess(
        self,
        input_batch: InputBatch,
        sampled_tokens: torch.Tensor,
        num_sampled: torch.Tensor,
    ) -> None:
        """Advance per-request bookkeeping once a step has been sampled.

        Two pieces of state are updated:
          1. ``post_update`` receives the batch's ``idx_mapping``, the
             request-state tensors ``num_computed_tokens`` and
             ``last_sampled_tokens``, the step's ``sampled_tokens`` /
             ``num_sampled`` and the batch's ``query_start_loc`` /
             ``cu_num_logits``, and updates the computed-token counts.
          2. The NumPy array ``num_computed_prefill_tokens`` is advanced, for
             each scheduled request, by that request's scheduled token count.
             The result is capped at the request's ``prefill_len``.

        Args:
            input_batch: The batch that was just executed.
            sampled_tokens: Sampled token ids produced by the sampler.
            num_sampled: Number of tokens sampled for each request.
        """
        states = self.req_states

        # Update the number of computed tokens.
        post_update(
            input_batch.idx_mapping,
            states.num_computed_tokens,
            states.last_sampled_tokens,
            sampled_tokens,
            num_sampled,
            input_batch.query_start_loc,
            input_batch.cu_num_logits,
        )

        # Update the number of computed prefill tokens, never letting it run
        # past the request's prefill length.
        # TODO(woosuk): Simplify this.
        rows = input_batch.idx_mapping_np
        prefill_progress = states.num_computed_prefill_tokens
        advanced = prefill_progress[rows] + input_batch.num_scheduled_tokens
        upper_bound = states.prefill_len.np[rows]
        prefill_progress[rows] = np.minimum(advanced, upper_bound)

    @torch.inference_mode()
    def propose_draft(
        self,
        input_batch: InputBatch,
        sampling_metadata: SamplingMetadata,
        last_hidden_states: torch.Tensor,
        aux_hidden_states: list[torch.Tensor] | None,
        num_sampled: torch.Tensor,
    ) -> torch.Tensor:
        """Run the speculator and record its draft tokens for each request.

        Inside the ``spec_decode_event`` async barrier, each request's prefill
        token at its current ``num_computed_prefill_tokens`` position is
        gathered into the ``next_prefill_tokens`` staging buffer. That buffer
        is then copied to the GPU. The copied tensor, together with the step's
        hidden states, ``num_sampled`` and ``req_states.last_sampled_tokens``,
        is passed to ``self.speculator.propose``.

        Returns:
            The draft tokens returned by the speculator. They are also written
            into ``req_states.draft_tokens`` at the batch's request slots.
        """
        states = self.req_states
        rows = input_batch.idx_mapping_np
        count = input_batch.num_reqs
        staging = self.input_buffers.next_prefill_tokens

        with async_barrier(self.spec_decode_event):
            # Position of the next not-yet-computed prefill token per request.
            cursor = states.num_computed_prefill_tokens[rows]
            staging.np[:count] = states.prefill_token_ids[rows, cursor]
            next_prefill_tokens = staging.copy_to_gpu(count)

        assert self.speculator is not None
        drafts = self.speculator.propose(
            input_batch,
            sampling_metadata,
            last_hidden_states,
            aux_hidden_states,
            num_sampled,
            states.last_sampled_tokens,
            next_prefill_tokens,
        )
        states.draft_tokens[input_batch.idx_mapping] = drafts
        return drafts

    def get_cudagraph_and_dp_padding(
        self,
        scheduler_output: SchedulerOutput,
    ) -> tuple[CUDAGraphMode, int, torch.Tensor | None]:
        """Choose the CUDA graph mode and padded token count for this step.

        Without data parallelism (``dp_size == 1``), only CUDA graph
        availability matters. If ``cudagraph_manager`` returns a size, the
        batch is padded to it and run as a full CUDA graph. Otherwise the step
        runs eagerly on the exact scheduled token count.

        With data parallelism, each rank reports its token count and a
        candidate CUDA graph size through ``get_batch_metadata_across_dp``.
        The candidate is ``0`` when the rank has no tokens and ``-1`` when no
        CUDA graph size is available. If every rank has a non-negative
        candidate, all ranks pad to the maximum and run a full CUDA graph.
        Otherwise all ranks run eagerly.

        Returns:
            ``(cudagraph_mode, num_tokens_after_padding, num_tokens_across_dp)``,
            where ``num_tokens_after_padding`` is always a Python ``int`` and
            ``num_tokens_across_dp`` is ``None`` when DP is disabled.
        """
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        if self.dp_size == 1:
            # No DP. Only consider CUDA graphs.
            if total_num_scheduled_tokens == 0:
                # Special case: no tokens to run.
                return CUDAGraphMode.NONE, 0, None

            cudagraph_size = self.cudagraph_manager.get_cudagraph_size(
                scheduler_output, total_num_scheduled_tokens
            )
            if cudagraph_size is not None:
                # Use full CUDA graph.
                return CUDAGraphMode.FULL, cudagraph_size, None
            # Fall back to eager mode.
            # TODO(woosuk): Support piecewise CUDA graphs.
            return CUDAGraphMode.NONE, total_num_scheduled_tokens, None

        # Consider DP padding and CUDA graph.
        if total_num_scheduled_tokens == 0:
            # Special handling is needed for 0.
            cudagraph_size_before_dp: int | None = 0
        else:
            cudagraph_size_before_dp = self.cudagraph_manager.get_cudagraph_size(
                scheduler_output, total_num_scheduled_tokens
            )
            if cudagraph_size_before_dp is None:
                # -1 signals "no CUDA graph available" to the other ranks.
                cudagraph_size_before_dp = -1

        assert cudagraph_size_before_dp is not None
        num_tokens_across_dp, cudagraph_size_across_dp = get_batch_metadata_across_dp(
            total_num_scheduled_tokens,
            cudagraph_size_before_dp,
            self.dp_size,
            self.dp_rank,
        )
        if bool((cudagraph_size_across_dp >= 0).all()):
            # If all ranks can use CUDA graph, pad to the maximum number of
            # tokens across DP and use CUDA graph.
            num_tokens_after_padding = int(cudagraph_size_across_dp.max().item())
            cudagraph_mode = CUDAGraphMode.FULL
        else:
            # If any of the ranks cannot use CUDA graph, use eager mode for all
            # ranks. No padding is needed except for ranks that have no tokens
            # to run; those are clamped to 1 token.
            num_tokens_across_dp = torch.clamp(num_tokens_across_dp, min=1)
            # Indexing a tensor yields a 0-dim tensor. Convert it to an int so
            # both branches honor the declared return type.
            num_tokens_after_padding = int(num_tokens_across_dp[self.dp_rank].item())
            cudagraph_mode = CUDAGraphMode.NONE
        return cudagraph_mode, num_tokens_after_padding, num_tokens_across_dp

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: SchedulerOutput,
        intermediate_tensors: Any | None = None,
        dummy_run: bool = False,
    ) -> ModelRunnerOutput | None:
        """Prepare the inputs for one step and run the model forward pass.

        The forward result is not returned directly. It is stashed in
        ``self.execute_model_state`` as
        ``(hidden_states, input_batch, sampling_metadata)`` for
        ``sample_tokens`` to consume.

        Args:
            scheduler_output: The scheduler's decisions for this step.
            intermediate_tensors: Must be ``None`` (asserted).
            dummy_run: When true, execute a dummy batch for DP instead of real
                inputs. In that case ``sampling_metadata`` is ``None``.

        Returns:
            ``EMPTY_MODEL_RUNNER_OUTPUT`` when there is nothing to run, either
            locally (non-dummy) or on all DP ranks. Otherwise ``None``.
        """
        assert intermediate_tensors is None
        nothing_scheduled = scheduler_output.total_num_scheduled_tokens == 0
        if nothing_scheduled and not dummy_run:
            # No need to run the model; only refresh the request states.
            with async_barrier(self.input_prep_event):
                self.update_states(scheduler_output)
                return EMPTY_MODEL_RUNNER_OUTPUT

        # NOTE: Call this before the async barrier so CPU all-reduce and
        # GPU execution can overlap.
        mode, padded_num_tokens, tokens_per_dp_rank = (
            self.get_cudagraph_and_dp_padding(scheduler_output)
        )
        with async_barrier(self.input_prep_event):
            self.update_states(scheduler_output)
            if padded_num_tokens == 0:
                # All DP ranks have zero tokens to run.
                return EMPTY_MODEL_RUNNER_OUTPUT

            if not dummy_run:
                # Common case: prepare the real inputs and copy them into the
                # input buffers.
                batch = self.prepare_inputs(scheduler_output, padded_num_tokens)

                # NOTE(woosuk): Sampling metadata should be built under the
                # async barrier to avoid race conditions.
                logits_positions = batch.positions[batch.logits_indices]
                sampling = self.req_states.make_sampling_metadata(
                    batch.idx_mapping_np, logits_positions
                )
                if batch.num_draft_tokens > 0:
                    sampling = self.req_states.expand_sampling_metadata(
                        sampling, batch.cu_num_logits
                    )

                if self.lora_config:
                    # Activate LoRA adapters.
                    self._set_active_loras(
                        *self.req_states.make_lora_inputs(
                            batch.req_ids,
                            batch.idx_mapping_np,
                            batch.num_scheduled_tokens,
                        )
                    )
            else:
                # No actual tokens to run. A dummy run for DP.
                batch = InputBatch.make_dummy(
                    num_reqs=min(padded_num_tokens, self.max_num_reqs),
                    num_tokens=padded_num_tokens,
                    input_buffers=self.input_buffers,
                    device=self.device,
                )
                self.prepare_dummy_attn_metadata(batch)
                sampling = None

        if mode == CUDAGraphMode.FULL:
            # NOTE(woosuk): No input tensors are passed here; they were
            # already copied into the CUDA graph input buffers.
            hidden = self.cudagraph_manager.run(batch.num_tokens_after_padding)
        else:
            # Eager PyTorch execution.
            # TODO(woosuk): Support piecewise CUDA graph.
            with set_forward_context(
                batch.attn_metadata,
                self.vllm_config,
                num_tokens=batch.num_tokens_after_padding,
                cudagraph_runtime_mode=mode,
                num_tokens_across_dp=tokens_per_dp_rank,
            ):
                hidden = self.model(
                    input_ids=batch.input_ids,
                    positions=batch.positions,
                )

        self.execute_model_state = hidden, batch, sampling
        return None

    @torch.inference_mode()
    def sample_tokens(
        self,
        grammar_output: GrammarOutput | None,
    ) -> AsyncOutput | ModelRunnerOutput:
        """Sample tokens for the batch stashed by ``execute_model``.

        The method:
          1. Consumes (and clears) ``self.execute_model_state``.
          2. Samples tokens and computes prompt logprobs.
          3. Wraps the results in an ``AsyncOutput``.
          4. Updates request states via ``postprocess``.
          5. Optionally proposes draft tokens when ``do_spec_decode`` is set.

        Args:
            grammar_output: Grammar constraints forwarded to ``self.sample``,
                or ``None``.

        Returns:
            The ``AsyncOutput`` itself when async scheduling is enabled.
            Otherwise the result of ``AsyncOutput.get_output()``.
        """
        assert self.execute_model_state is not None
        hidden_states, input_batch, sampling_metadata = self.execute_model_state
        self.execute_model_state = None  # type: ignore
        assert sampling_metadata is not None

        sampler_output, num_sampled_tokens = self.sample(
            hidden_states, input_batch, sampling_metadata, grammar_output
        )
        prompt_logprobs_dict = self.compute_prompt_logprobs(hidden_states, input_batch)

        req_ids = input_batch.req_ids
        # NOTE(woosuk): req_id_to_index is unused in this model runner.
        # Only for compatibility with the existing model runner and scheduler.
        req_id_to_index = {rid: pos for pos, rid in enumerate(req_ids)}
        async_output = AsyncOutput(
            model_runner_output=ModelRunnerOutput(
                req_ids=req_ids,
                req_id_to_index=req_id_to_index,
                sampled_token_ids=None,  # type: ignore
                logprobs=None,
                prompt_logprobs_dict=prompt_logprobs_dict,  # type: ignore
                pooler_output=[],
                kv_connector_output=None,
                num_nans_in_logits=None,
            ),
            sampler_output=sampler_output,
            num_sampled_tokens=num_sampled_tokens,
            copy_stream=self.output_copy_stream,
            copy_event=self.output_copy_event,
        )

        # Postprocess only after AsyncOutput exists, so `copy_event` is
        # recorded first. The async D2H copy then does not wait on postprocess.
        self.postprocess(
            input_batch, sampler_output.sampled_token_ids, num_sampled_tokens
        )
        if self.do_spec_decode:
            self.propose_draft(
                input_batch,
                sampling_metadata,
                hidden_states,
                None,  # aux_hidden_states
                num_sampled_tokens,
            )

        return async_output if self.use_async_scheduling else async_output.get_output()