multi_step_model_runner.py 38.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import dataclasses
import functools
6
from dataclasses import dataclass, field
7
8
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
                    Union)
9
10
11
12
13

import torch

from vllm.distributed import get_pp_group
from vllm.logger import init_logger
14
15
16
17
from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs,
                                                SamplerOutput,
                                                SamplingMetadata, get_logprobs,
                                                get_pythonized_sample_results)
18
from vllm.platforms import current_platform
19
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
20
                           Logprob, SequenceGroupMetadata, SequenceOutput)
youkaichao's avatar
youkaichao committed
21
from vllm.utils import PyObjectCache, async_tensor_h2d, current_stream
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from vllm.worker.model_runner import (GPUModelRunnerBase,
                                      ModelInputForGPUWithSamplingMetadata)
from vllm.worker.model_runner_base import (
    BroadcastableModelInput, _init_attn_metadata_from_tensor_dict,
    _init_frozen_model_input_from_tensor_dict,
    _init_sampling_metadata_from_tensor_dict)

from ..model_executor.model_loader.tensorizer import TensorizerConfig

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend

logger = init_logger(__name__)

36
37
38
# Attention backends that can run multi-step scheduling.
MULTI_STEP_ATTENTION_BACKENDS = [
    "FLASH_ATTN",
    "ROCM_FLASH",
    "FLASHINFER",
    "NO_ATTENTION",
]
# Backends that can run multi-step scheduling together with chunked prefill.
MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["FLASH_ATTN", "FLASHINFER"]


def _get_supported_attention_backends(
        chunked_prefill_enabled: bool) -> List[str]:
    """Return the attention-backend names usable in the given configuration.

    Args:
        chunked_prefill_enabled: whether chunked prefill is enabled together
            with multi-step scheduling.

    Returns:
        The module-level list of supported backend names (the list object
        itself, not a copy).
    """
    return (MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS
            if chunked_prefill_enabled else MULTI_STEP_ATTENTION_BACKENDS)
47

48

49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
def seq_output_builder():
    """Create a placeholder ``SequenceOutput`` for object pooling.

    Used as the builder that ``PythonizationCache`` hands to
    ``PyObjectCache``. Every field holds a dummy value (zeros and a single
    logprob entry of +inf).
    """
    placeholder_logprob = Logprob(logprob=float('inf'),
                                  rank=None,
                                  decoded_token=None)
    return SequenceOutput(0, 0, {0: placeholder_logprob})


def completion_seq_group_output_builder():
    """Create an empty ``CompletionSequenceGroupOutput`` for object pooling.

    Used as the builder that ``PythonizationCache`` hands to
    ``PyObjectCache``: an empty list as the first argument and ``None`` as
    the second.
    """
    no_samples = []
    return CompletionSequenceGroupOutput(no_samples, None)


# Used by pythonization to reduce python object allocations
class PythonizationCache:
    """Object pools that let pythonization reuse output objects.

    Keeps one ``PyObjectCache`` for ``SequenceOutput`` objects and one for
    ``CompletionSequenceGroupOutput`` objects, so fewer Python objects are
    allocated per step.
    """

    def __init__(self):
        make_seq_output = seq_output_builder
        make_group_output = completion_seq_group_output_builder
        self.cached_seq_output = PyObjectCache(make_seq_output)
        self.cached_completion_seq_group_output = PyObjectCache(
            make_group_output)

    def reset(self):
        """Reset both object pools (sequence outputs first, then groups)."""
        for pool in (self.cached_seq_output,
                     self.cached_completion_seq_group_output):
            pool.reset()


72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
@dataclass
class ModelOutput:
    """The output of a single model forward pass.

    The sampler_output_ready_event is set when the tensors in
    sampler_output are ready (the model+sampler forward pass has
    completed). We use the event to synchronize the GPU->CPU transfer,
    which we want to only run when the data has been written to the
    GPU tensors. Until the event is ready, the tensors in sampler_output
    will have garbage data.

    There are two scenarios:
    1. The output tensors are ready and we can pythonize them immediately.
    2. The output tensors are not ready and we need to wait for the event to be
    ready.

    An output constructed without a ready event
    (``sampler_output_ready_event=None``) can be cached but not pythonized.
    """
    sampler_output: SamplerOutput
    # None for outputs that are only cached, e.g. those appended through
    # StatefulModelInput.add_sampler_output().
    sampler_output_ready_event: Optional[torch.cuda.Event]
    sampled_token_ids: Optional[torch.Tensor] = None
    pythonized: bool = False

    # On-device tensor containing the logprobs of each token.
    logprobs: Optional["torch.Tensor"] = None
    # Object pool handed to pythonization; None when caching is disabled
    # (MultiStepModelRunner disables it under pipeline parallelism).
    pythonization_cache: Optional[PythonizationCache] = None

    def pythonize(self, input_metadata: "StatefulModelInput",
                  copy_stream: torch.cuda.Stream,
                  pinned_sampled_token_buffer: torch.Tensor) -> None:
        """Pythonize the output, blocking until its tensors are ready.

        No-op if the output has already been pythonized.
        """
        if not self.pythonized:
            self._pythonize_sampler_output(input_metadata, copy_stream,
                                           pinned_sampled_token_buffer, True)
            self.pythonized = True

    def maybe_pythonize(self, input_metadata: "StatefulModelInput",
                        copy_stream: torch.cuda.Stream,
                        pinned_sampled_token_buffer: torch.Tensor) -> None:
        """Pythonize the output only if its tensors are already ready.

        Non-blocking: if the forward pass has not finished yet, the output is
        left unpythonized and ``self.pythonized`` stays False.
        """
        if not self.pythonized:
            self.pythonized = self._pythonize_sampler_output(
                input_metadata, copy_stream, pinned_sampled_token_buffer,
                False)

    def _pythonize_sampler_output(self, input_metadata: "StatefulModelInput",
                                  copy_stream: torch.cuda.Stream,
                                  pinned_sampled_token_buffer: torch.Tensor,
                                  blocking: bool) -> bool:
        """
        If blocking is set, will block until the forward pass for the output is
        ready and pythonize the output. Upon completing Pythonization, erases
        self.logprobs (note that a non-blocking call that is performed when
        the sampler output is not yet ready, will not erase self.logprobs.)

        Returns:
            True if this call pythonized the output; False if the call was
            non-blocking and the sampler output was not ready yet.
        """
        assert self.sampled_token_ids is not None
        # Outputs cached without an event cannot be synchronized on; fail
        # with a clear message instead of an AttributeError on None.
        assert self.sampler_output_ready_event is not None, (
            "Cannot pythonize a ModelOutput that has no "
            "sampler_output_ready_event")
        if not blocking and not self.sampler_output_ready_event.query():
            return False

        if blocking:
            self.sampler_output_ready_event.synchronize()
        with torch.cuda.stream(copy_stream):
            _pythonize_sampler_output(input_metadata, self.sampler_output,
                                      pinned_sampled_token_buffer,
                                      self.sampled_token_ids, self.logprobs,
                                      self.pythonization_cache)

        # Erase the logprobs GPU-side tensor.
        # Note that although _pythonize_sampler_output() runs in its
        # own CUDA stream, nonetheless _pythonize_sampler_output()
        # cannot return until Pythonization is complete; therefore
        # we know that by the time the CPU reaches this point,
        # `self.logprobs` is no longer needed.
        self.logprobs = None
        return True


@dataclass(frozen=False)
class StatefulModelInput(BroadcastableModelInput):
    """Model input carried across the steps of one multi-step execution.

    Wraps the single-step ``frozen_model_input`` that is passed to the base
    model runner, and adds the per-step state (step counter, cached outputs,
    ping-pong CUDA events, sequence/query counts) used to advance it.
    """
    # actual frozen model input dataclass passed to _base_model_runner
    frozen_model_input: Optional[ModelInputForGPUWithSamplingMetadata] = None

    # list of model outputs for each step, may not be all pythonized
    cached_outputs: List[ModelOutput] = field(default_factory=list)

    # used to pass sampled token ids from the last step to the current step for
    # TP workers. Used to append to end of outputs and used by advance_step
    last_sampled_token_ids: Optional[torch.Tensor] = None
    current_step: int = 0
    is_multi_step: bool = True
    is_last_step: bool = False
    is_first_multi_step: bool = False
    base_output_proc_callback: Optional[Callable] = None
    # ping-pong data structures for multi-step to wait on the previous step.
    # Build two distinct events: `[Event()] * 2` would put the *same* event
    # object in both slots.
    step_cuda_events: List[current_platform.Event] = field(
        default_factory=lambda: [
            current_platform.Event(blocking=True) for _ in range(2)
        ])
    num_seqs: int = -1
    num_queries: int = -1
    num_single_step_prefills: int = 0

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """Serialize this input into a dict that can be broadcast.

        Starts from the frozen model input's broadcastable dict and adds the
        multi-step scalar fields and ``last_sampled_token_ids``.
        ``cached_outputs``, ``step_cuda_events`` and
        ``base_output_proc_callback`` are not included.
        """
        assert self.frozen_model_input is not None
        tensor_dict = self.frozen_model_input.as_broadcastable_tensor_dict()
        new_tensor_dict = {
            'last_sampled_token_ids': self.last_sampled_token_ids,
            'current_step': self.current_step,
            'is_multi_step': self.is_multi_step,
            'is_last_step': self.is_last_step,
            'is_first_multi_step': self.is_first_multi_step,
            'num_seqs': self.num_seqs,
            'num_queries': self.num_queries,
            'num_single_step_prefills': self.num_single_step_prefills,
        }
        tensor_dict.update(new_tensor_dict)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "StatefulModelInput":
        """Rebuild a StatefulModelInput from a broadcast tensor dict.

        Sampling metadata, attention metadata (only when ``attn_backend`` is
        given) and the frozen model input are reconstructed first; the
        remaining keys are passed to the constructor.
        """
        tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        tensor_dict = _init_frozen_model_input_from_tensor_dict(
            ModelInputForGPUWithSamplingMetadata, tensor_dict)

        return cls(**tensor_dict)

    def record_step_event(self, current_stream: torch.cuda.Stream):
        # record the event for the current step so that the next step can sync
        # on it. We modulo by 2 to keep the events in a circular buffer and
        # support any attn backends that may be supported in the future. ie
        # Flashinfer would want two DecodeWrappers to overlap the CPU and GPU.
        self.step_cuda_events[self.current_step & 1] = \
            torch.cuda.Event(blocking=True)
        self.step_cuda_events[self.current_step & 1].record(current_stream)

    def wait_previous_step(self):
        # These cuda events are an explicit synchronization to ensure that
        # advance_step() (for other attn backends that may be supported in the
        # future) do not clobber any data structures that is also used by any
        # enqueued forwards steps. For distributed case, only a single event is
        # needed, but for single GPU case, since we can let the CPU run much
        # further ahead, two events allow us to overlap the advance_step with
        # the previous forward (ie using two DecodeWrappers for flashinfer
        # backend)
        self.step_cuda_events[(self.current_step + 1) & 1].wait()

    def add_sampler_output(self,
                           sampler_output: SamplerOutput,
                           sampled_token_ids: Optional[torch.Tensor] = None):
        """Append ``sampler_output`` to ``cached_outputs``.

        The appended ModelOutput has no ready event and is marked as not
        pythonized.
        """
        self.cached_outputs.append(
            ModelOutput(sampler_output=sampler_output,
                        sampler_output_ready_event=None,
                        sampled_token_ids=sampled_token_ids,
                        pythonized=False))

    def maybe_advance_sampling_metadata(self, device: str, pin_memory: bool):
        """
        sampling_metadata.selected_token_indices is constructed for the
        first-step in Multi-Step. However, when chunked-prefill is enabled with
        multi-step, the scheduled prompts are fully processed in the
        first-step and are processed as decodes in the rest of the steps.
        This function updates the sampling_metadata.selected_token_indices
        to account for this conversion.

        Example:
        Let 2 prompts and 2 decodes be scheduled together. Let the
        num-tokens to process for the 2 prompts be 5 and 8 respectively.

        In that case, sampling_metadata.selected_token_indices will be
        [4, 12, 13, 14] as it is constructed for the first-step in
        multi-step.
        However, the prompts turn into decodes after the first-step
        and the num-tokens for the previously-prompt sequences will
        be 1 and 1 as they are decodes now. The selected_token_indices
        must therefore be updated to [0, 1, 2, 3].

        Only the last pipeline-parallel rank updates the indices; other
        ranks return without changes.
        """
        assert self.current_step == 1 and self.num_single_step_prefills > 0
        if not get_pp_group().is_last_rank:
            return

        assert self.frozen_model_input is not None
        assert self.frozen_model_input.sampling_metadata is not None
        self.frozen_model_input.sampling_metadata.selected_token_indices =  \
            async_tensor_h2d(list(range(self.num_queries)),
                             dtype=torch.long,
                             target_device=device,
                             pin_memory=pin_memory)

    def maybe_advance_frozen_model_input(self, device: str, pin_memory: bool):
        """
        Advancing the datastructures of StatefulModelInput::frozen_model_input
        is only required when prefills are scheduled with decodes to run in
        multi-step. This advancement/correction is required to account for
        the conversion of Prefills to Decodes after the first multi-step.
        """
        if self.current_step != 1 or self.num_single_step_prefills == 0:
            return

        assert self.frozen_model_input is not None
        fmi = self.frozen_model_input

        # Truncate input_tokens
        assert fmi.input_tokens is not None
        assert fmi.input_tokens.shape[0] >= self.num_seqs
        fmi_new_input_tokens: torch.Tensor = fmi.input_tokens[:self.num_seqs]

        # Update frozen_model_input::input_positions.
        assert fmi.input_positions is not None
        assert fmi.input_positions.shape[0] >= self.num_seqs
        fmi_new_input_positions: torch.Tensor = fmi.input_positions[:self.
                                                                    num_seqs]

        # Assert unsupported
        assert fmi.lora_mapping is None
        assert fmi.lora_requests is not None
        assert len(fmi.lora_requests) == 0
        assert fmi.attn_metadata is not None
        assert fmi.multi_modal_kwargs is not None
        assert len(fmi.multi_modal_kwargs) == 0

        self.frozen_model_input = dataclasses.replace(
            self.frozen_model_input,
            input_tokens=fmi_new_input_tokens,
            input_positions=fmi_new_input_positions)

        self.maybe_advance_sampling_metadata(device, pin_memory)

301
302
303
304
305
306
307
308
309

# StatefulModelInput is not a subclass of ModelInputForGPU, but it wraps the
# actual input dataclass (frozen_model_input) and adds multi-step metadata.
# mypy: disable-error-code=type-var
class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
    # mypy: enable-error-code=type-var
    """Wraps a single-step GPU model runner to execute multi-step decoding.

    Each call to ``execute_model`` runs exactly one step through the base
    model runner, records a per-step CUDA event, and caches the sampler
    output in ``StatefulModelInput.cached_outputs``. Cached outputs are
    pythonized opportunistically once their tensors are ready. On the last
    step the remaining outputs are pythonized with blocking.
    """

    def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Check attention backend support.
        supported_attention_backends: List[str] = \
            _get_supported_attention_backends(
                self.scheduler_config.chunked_prefill_enabled)
        if self.attn_backend.get_name() not in supported_attention_backends:
            ms_config_str: str = (
                "Multi-Step + Chunked-Prefill"
                if self.scheduler_config.chunked_prefill_enabled else
                "Multi-Step")
            raise ValueError(
                f"{ms_config_str} not supported for attention backend: "
                f"{self.attn_backend.get_name()}. Set VLLM_ATTENTION_BACKEND "
                f"to a value from {supported_attention_backends}.")

        # uses the base model runner to execute the model and wraps it with
        # multi-step logic
        self._base_model_runner: GPUModelRunnerBase = base_model_runner

        self.is_multi_step = self.scheduler_config.is_multi_step
        # Lazily allocated in execute_model() on the driver/last-rank worker.
        self.pinned_sampled_token_ids: Optional[torch.Tensor] = None

        # Using the PythonizationCache in Pipeline-Parallel clobbers the
        # SequenceOutput and CompletionSequenceGroupOutput object.
        # When cache-reset happens at the last step of a multi-step
        # execution, there may be other on-going single-step/multi-step
        # executions. The current caching implementation does not check
        # for this.
        self.pythonization_cache = PythonizationCache() \
            if self.parallel_config.pipeline_parallel_size == 1 else None

    @functools.cached_property
    def _copy_stream(self):
        # used to copy tensors from GPU to CPU asynchronously
        return torch.cuda.Stream()

    def make_model_input_from_broadcasted_tensor_dict(
            self, tensor_dict: Dict[str, Any]) -> StatefulModelInput:
        model_input = (StatefulModelInput.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        ))
        return model_input

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> StatefulModelInput:
        """Wrap the base runner's prepared input in a StatefulModelInput.

        Records the number of sequences, queries and prefills of the first
        step so later steps can be advanced.
        """
        frozen_model_input: ModelInputForGPUWithSamplingMetadata = \
            self._base_model_runner.prepare_model_input(
                seq_group_metadata_list,
                virtual_engine,
                finished_requests_ids)

        assert frozen_model_input.query_lens is not None
        assert frozen_model_input.seq_lens is not None
        assert frozen_model_input.attn_metadata is not None
        num_queries = len(frozen_model_input.query_lens)
        num_seqs = len(frozen_model_input.seq_lens)
        num_single_step_prefills = frozen_model_input.attn_metadata.num_prefills

        model_input = StatefulModelInput(
            frozen_model_input=frozen_model_input,
            num_seqs=num_seqs,
            num_queries=num_queries,
            num_single_step_prefills=num_single_step_prefills)

        return model_input

    def _async_process_outputs(self, model_input: StatefulModelInput,
                               output_proc_callback: Callable):
        """Async output-processing hook run by the base model runner.

        Invokes ``output_proc_callback`` once. Then it walks the cached outputs
        in step order, pythonizing each not-yet-pythonized output without
        blocking. Every newly pythonized output is appended to the callback's
        ``ctx`` and the callback is invoked again. Stops at the first output
        whose tensors are not ready yet, so outputs stay in step order.
        """
        output_proc_callback()

        for step_num, model_output in enumerate(model_input.cached_outputs):
            if model_output.pythonized:
                continue

            model_output.maybe_pythonize(model_input, self._copy_stream,
                                         self.pinned_sampled_token_ids)
            if not model_output.pythonized:
                break

            # Assumes output_proc_callback is a functools.partial bound with
            # a ``ctx`` keyword argument.
            ctx = output_proc_callback.keywords["ctx"]
            ctx.append_output(
                outputs=[model_output.sampler_output],
                seq_group_metadata_list=ctx.seq_group_metadata_list,
                scheduler_outputs=ctx.scheduler_outputs,
                is_async=False,
                is_last_step=False,
                is_first_step_output=step_num == 0)

            output_proc_callback()

    def _final_process_outputs(
            self, model_input: StatefulModelInput,
            output_proc_callback: Optional[Callable]) -> List[SamplerOutput]:
        """Pythonize (blocking) all cached outputs at the last step.

        Without an async callback, every output is returned. With one, outputs
        of earlier steps that are pythonized here are queued on the
        callback's ``ctx``. Only the output of the last step is returned, and
        only if it had not been pythonized before this call.
        """
        assert model_input.frozen_model_input is not None

        has_async_callback = output_proc_callback is not None

        outputs = []
        for step_num, output in enumerate(model_input.cached_outputs):
            is_last_step = step_num == len(model_input.cached_outputs) - 1

            # For non-async case:
            #   -- We simply add the outputs
            # For async case:
            #   -- Invoke callback, pythonize, add to callback queue and repeat
            #   -- For last output, just add to callback queue
            if has_async_callback:
                assert output_proc_callback is not None

                # Invoke callback before pythonize (to overlap with GPU)
                output_proc_callback()

                # Pythonize
                if not output.pythonized:
                    output.pythonize(model_input, self._copy_stream,
                                     self.pinned_sampled_token_ids)

                    # For non last step, add to callback queue to chain
                    # callbacks=>pythonize pairs (for GPU overlap)
                    if not is_last_step:
                        ctx = output_proc_callback.keywords[  # type: ignore
                            "ctx"]  # type: ignore
                        ctx.append_output(
                            outputs=[output.sampler_output],
                            seq_group_metadata_list=ctx.
                            seq_group_metadata_list,
                            scheduler_outputs=ctx.scheduler_outputs,
                            is_async=False,
                            is_last_step=False,
                            is_first_step_output=step_num == 0)
                    else:
                        outputs.append(output.sampler_output)
            else:
                output.pythonize(model_input, self._copy_stream,
                                 self.pinned_sampled_token_ids)
                outputs.append(output.sampler_output)

        return outputs

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: StatefulModelInput,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
        """
        Execute the model for a single step and update multi-step
        metadata.

        Args:
            model_input: stateful multi-step input. Its ``current_step``,
                ``cached_outputs`` and ``frozen_model_input`` are updated.
            kv_caches: not forwarded; the base runner is called with None.
            intermediate_tensors: forwarded to the base model runner.
            num_steps: must be 1.

        Returns:
            For warm-up runs (``is_multi_step`` False), the base runner's
            output. Otherwise: IntermediateTensors on non-last pipeline ranks,
            ``[]`` on non-driver workers, all pythonized outputs on the last
            step, and this step's ``[SamplerOutput]`` otherwise.
        """
        assert num_steps == 1, "MultiStepModelRunner only supports num_steps=1"
        frozen_model_input = model_input.frozen_model_input
        assert frozen_model_input is not None

        # path for warm up runs
        if not model_input.is_multi_step:
            return self._base_model_runner.execute_model(
                frozen_model_input, None, intermediate_tensors, num_steps)

        # make sure we skip the sampler on the last rank and only pythonize
        # if CPU is ahead.
        if self.is_driver_worker and get_pp_group().is_last_rank:
            if self.pinned_sampled_token_ids is None:
                self.pinned_sampled_token_ids = torch.zeros(
                    (self.scheduler_config.max_num_seqs, 1),
                    dtype=torch.long,
                    device="cpu",
                    pin_memory=True)

            self._base_model_runner.sampler.include_gpu_probs_tensor = True
            if frozen_model_input.sampling_metadata:
                frozen_model_input.sampling_metadata.skip_sampler_cpu_output = (
                    True)

        # some pre-execute model logic for multi-step:
        #   - if it's the first step, we need to reset the sampling tensors
        #   - if it's not the first step, we need to advance the step using the
        #   appended sampler output from last iteration
        #   - also maybe pythonize if CPU is ahead of GPU

        stream = current_stream()
        if not model_input.is_first_multi_step:
            # Explicitly block on the previous step's forward to make sure we
            # don't clobber any GPU tensors still in use.
            # This is not needed for flashattn backend, but for other attn
            # backends such as flashinfer that performs extra CPU operations on
            # input metadata we may need to synchronize any CPU operations that
            # might clobber enqueued forwards. (prevents CPU from running too
            # far ahead if needed)
            model_input.wait_previous_step()
            model_input = self._advance_step(
                model_input, model_input.cached_outputs[-1].sampler_output)

            # frozen_model_input may have been updated
            frozen_model_input = model_input.frozen_model_input
            assert frozen_model_input is not None

        if model_input.base_output_proc_callback is None:
            model_input.base_output_proc_callback = \
                frozen_model_input.async_callback

        if frozen_model_input.async_callback is not None:
            assert model_input.base_output_proc_callback is not None
            async_callback = functools.partial(
                self._async_process_outputs,
                model_input=model_input,
                output_proc_callback=model_input.base_output_proc_callback)

            model_input.frozen_model_input = dataclasses.replace(  # type: ignore
                model_input.frozen_model_input,
                async_callback=async_callback)
            # Update the local instance
            frozen_model_input = model_input.frozen_model_input
            assert frozen_model_input is not None

        # Execute the model
        output = self._base_model_runner.execute_model(frozen_model_input,
                                                       None,
                                                       intermediate_tensors,
                                                       num_steps=1)

        # record the event for the current step so that the next step can sync
        model_input.record_step_event(stream)

        if get_pp_group().is_last_rank and self.is_driver_worker:
            assert isinstance(output, list)
            assert len(
                output
            ) == 1, "MultiStepModelRunner requires single-step base_models"

            # event for the pythonization so that we only pythonize if the
            # tensors are ready. May be able to be combined with the step event
            output_ready_event = torch.cuda.Event()
            output_ready_event.record(stream)
            if self.parallel_config.pipeline_parallel_size > 1:
                output[0].sampled_token_ids_cpu = output[
                    0].sampled_token_ids.cpu()
            model_input.cached_outputs.append(
                ModelOutput(sampler_output=output[0],
                            sampler_output_ready_event=output_ready_event,
                            sampled_token_ids=output[0].sampled_token_ids,
                            pythonized=False,
                            logprobs=output[0].logprobs,
                            pythonization_cache=self.pythonization_cache))

            # These GPU tensors are not required by multi-step;
            # erase them to ensure they are not pythonized or
            # transferred to CPU
            output[0].sampled_token_ids = None
            output[0].sampled_token_probs = None
            output[0].logprobs = None

            # Pythonize the output if CPU is ahead and the previous step is
            # ready.
            if frozen_model_input.async_callback is None:
                for model_output in model_input.cached_outputs:
                    model_output.maybe_pythonize(model_input,
                                                 self._copy_stream,
                                                 self.pinned_sampled_token_ids)

        model_input.current_step += 1

        if not get_pp_group().is_last_rank:
            # Should be IntermediateTensors
            assert isinstance(output, IntermediateTensors)
            return output
        if not self.is_driver_worker:
            return []

        # Pythonize the output and block if needed since it is the last step
        if model_input.is_last_step:
            outputs = self._final_process_outputs(
                model_input, model_input.base_output_proc_callback)
            if self.pythonization_cache:
                self.pythonization_cache.reset()
            return outputs

        # should be [SamplerOutput]
        return output

    def _update_sampling_metadata(self, sampling_metadata: SamplingMetadata,
                                  num_seqs: Optional[int], num_queries: int):
        """Assert that ``sampling_metadata`` describes only simple decodes.

        Checks that there are ``num_queries`` decode seq groups, each
        sampling exactly its own index. ``num_seqs`` is not used.
        """
        assert sampling_metadata.num_prompts == 0
        assert len(sampling_metadata.seq_groups) == num_queries
        assert sampling_metadata.selected_token_indices.shape == (
            num_queries, )
        # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501

        # Verify that all sequences are decodes
        for i in range(num_queries):
            seq_group = sampling_metadata.seq_groups[i]

            assert seq_group.is_prompt is False  # No prompt
            assert seq_group.prompt_logprob_indices == []  # No prompt
            assert seq_group.sample_indices == [i]  # Simple
            assert seq_group.seq_len is None  # Decode
            assert seq_group.query_len is None  # Decode

    def _advance_step(self, model_input: StatefulModelInput,
                      out: SamplerOutput) -> StatefulModelInput:
        """Advance ``model_input`` to the next step of the multi-step run.

        Args:
            model_input: the stateful input of the current run.
            out: sampler output of the previous step. Not used; the sampled
                token ids are read from ``model_input.cached_outputs[-1]``.

        Returns:
            The same ``model_input`` object, updated.
        """
        model_input.maybe_advance_frozen_model_input(self.device,
                                                     self.pin_memory)
        frozen_model_input = model_input.frozen_model_input
        assert frozen_model_input is not None
        assert frozen_model_input.input_tokens is not None
        assert frozen_model_input.input_tokens.shape[0] == model_input.num_seqs
        attn_metadata = frozen_model_input.attn_metadata
        assert attn_metadata is not None

        # Prefills scheduled in the first step become decodes from step 1 on.
        turn_prefills_into_decodes: bool = (
            model_input.current_step == 1
            and model_input.num_single_step_prefills != 0)

        attn_metadata.advance_step(
            frozen_model_input,
            model_input.cached_outputs[-1].sampled_token_ids,
            self.block_size,
            model_input.num_seqs,
            model_input.num_queries,
            turn_prefills_into_decodes=turn_prefills_into_decodes)

        return model_input

    def load_model(self) -> None:
        self._base_model_runner.load_model()
        self.model_memory_usage = self._base_model_runner.model_memory_usage

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        return self._base_model_runner.save_sharded_state(
            path, pattern, max_size)

    def save_tensorized_model(self,
                              tensorizer_config: TensorizerConfig) -> None:
        return self._base_model_runner.save_tensorized_model(tensorizer_config)

    def profile_run(self) -> None:
        return self._base_model_runner.profile_run()

    def remove_all_loras(self):
        return self._base_model_runner.remove_all_loras()

    def capture_model(self, kv_caches: List[List]) -> None:
        return self._base_model_runner.capture_model(kv_caches)

    @property
    def vocab_size(self) -> int:
        return self._base_model_runner.vocab_size


680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
# Return type of deferred_pythonize_logprobs():
# (prompt logprobs per seq group, sample logprobs per seq group).
DeferredLogprobsReturnType = Tuple[
    Optional[List[Optional[PromptLogprobs]]],
    Optional[List[SampleLogprobs]],
]


def deferred_pythonize_logprobs(
    output: SamplerOutput,
    sampling_metadata: SamplingMetadata,
    logprobs_tensor: Optional[torch.Tensor],
) -> DeferredLogprobsReturnType:
    """Carry out the logprob Pythonization that multi-step scheduling defers.

    Two stages run in order:

    1. The sampler result is Pythonized from
       ``output.deferred_sample_results_args``. Those args are then cleared
       on ``output``.
    2. ``logprobs_tensor`` is converted into CPU-side prompt and sample
       logprob lists, using the sampler result from stage 1.

    Single-step scheduling and the ``profile_run()`` phase of multi-step
    scheduling do not need this deferred path.

    Args:
        output: sampler output whose Pythonization was deferred. Its
            ``deferred_sample_results_args`` is set to ``None`` on return.
        sampling_metadata: sampling metadata. Its ``seq_groups`` fixes how
            many entries each returned list must contain.
        logprobs_tensor: logprobs tensor forwarded to ``get_logprobs``;
            may be ``None``.

    Returns:
        ``(prompt_logprobs, sample_logprobs)``, each with exactly one entry
        per sequence group.
    """
    sample_results = get_pythonized_sample_results(
        output.deferred_sample_results_args)

    # Drop the deferred computation args so this GPU-side work is never
    # Pythonized or moved to CPU a second time.
    output.deferred_sample_results_args = None

    prompt_logprobs, sample_logprobs = get_logprobs(logprobs_tensor,
                                                    sampling_metadata,
                                                    sample_results)

    # Both lists must line up one-to-one with the sequence groups.
    assert len(prompt_logprobs) == len(sampling_metadata.seq_groups)
    assert len(sample_logprobs) == len(sampling_metadata.seq_groups)

    return prompt_logprobs, sample_logprobs


def _pythonize_sampler_output(
    model_input: StatefulModelInput,
    output: SamplerOutput,
    pinned_sampled_token_buffer: torch.Tensor,
    sampled_token_ids: torch.Tensor,
    logprobs_tensor: Optional[torch.Tensor],
    cache: Optional[PythonizationCache],
) -> None:
    """ This function is only called when the output tensors are ready.
    See [`ModelOutput`][vllm.worker.multi_step_model_runner.ModelOutput].

    Modifies `output.outputs` and `pinned_sampled_token_buffer` in-place,
    adding a Pythonized output data structure
    ([`CompletionSequenceGroupOutput`][vllm.sequence.CompletionSequenceGroupOutput])
    for each [`SequenceGroup`][vllm.sequence.SequenceGroup].

    Args:
      model_input: stateful model input; its `frozen_model_input` and that
                   input's `sampling_metadata` must both be set.
      output: sampler output; `output.outputs` must be empty on entry and
              receives one entry per sequence group.
      pinned_sampled_token_buffer: CPU-side pinned memory whose first
                                   `model_input.num_queries` rows receive
                                   a copy of the GPU-side token buffer.
      sampled_token_ids: GPU-side token buffer
      logprobs_tensor: GPU-side tensor containing
                       logprobs computed during sampling
      cache: optional object cache. When given, output objects are
             recycled from it rather than freshly allocated.
    """

    assert model_input.frozen_model_input is not None

    frozen_model_input = model_input.frozen_model_input
    assert frozen_model_input.sampling_metadata is not None
    sampling_metadata = frozen_model_input.sampling_metadata
    # samples generation should have been skipped
    assert not output.outputs

    pinned_buffer = pinned_sampled_token_buffer[:model_input.num_queries]

    # We guarantee output tensors are ready, so it is safe to
    # pythonize the sampler output & obtain CPU-side logprobs.
    #
    # However we should check whether logprobs pythonization may
    # be skipped entirely, i.e. because no logprobs were requested
    # or pythonization was not deferred. To that end,
    #
    # * `prompt_logprobs_are_requested_for_prefill` signals that
    #   there are *any* prefill-phase requests which specify that
    #   prompt logprobs should be returned.
    #
    # * `any_logprobs_are_requested` signals that there are any
    #   requests which (1) specify that sample logprobs should be
    #   returned, or (2) are in the prefill phase AND specify that
    #   prompt logprobs should be returned.
    #
    # Later on, these flags cause adjustments to the pythonization
    # process to accommodate logprobs.

    seq_groups = sampling_metadata.seq_groups
    prompt_logprobs_are_requested_for_prefill = any(
        sg.sampling_params.prompt_logprobs is not None and sg.is_prompt
        for sg in seq_groups)
    any_logprobs_are_requested = (
        prompt_logprobs_are_requested_for_prefill
        or any(sg.sampling_params.logprobs is not None for sg in seq_groups))

    if prompt_logprobs_are_requested_for_prefill:
        # CPU GPU sync, after gathering *only* sampled tokens (since
        # requesting prompt logprobs leads `sampled_token_ids` to
        # include prompt token ids in addition to sampled token ids.)
        sample_idx_tensor = torch.tensor(
            [sdx for sg in seq_groups for sdx in sg.sample_indices])
        pinned_buffer = pinned_buffer.copy_(
            sampled_token_ids[sample_idx_tensor, :], non_blocking=False)
    else:
        # CPU GPU sync
        pinned_buffer = pinned_buffer.copy_(sampled_token_ids,
                                            non_blocking=False)

    # this will not block as the tensors are already on CPU
    samples_list = pinned_buffer.tolist()

    skip_sampler_cpu_output = sampling_metadata.skip_sampler_cpu_output

    # *Don't* skip logprobs pythonization *if*:
    # * Any requests require logprobs to be returned in this
    # iteration AND
    # * These requests are being scheduled in a fashion which
    # defers pythonization (i.e. multi-step scheduling.)
    do_pythonize_logprobs = (skip_sampler_cpu_output
                             and any_logprobs_are_requested)
    (
        prompt_logprobs,
        sample_logprobs,
    ) = (deferred_pythonize_logprobs(output, sampling_metadata,
                                     logprobs_tensor)
         if do_pythonize_logprobs else (None, None))

    for sgdx, (seq_group,
               sample_result) in enumerate(zip(seq_groups, samples_list)):
        # Reminder: Please update docs/features/compatibility_matrix.md
        # If the feature combo become valid
        # (Check for Guided Decoding)
        assert not seq_group.sampling_params.logits_processors, (
            "Logits Processors are not supported in multi-step decoding")

        if do_pythonize_logprobs:
            assert prompt_logprobs is not None
            assert sample_logprobs is not None

            (
                group_prompt_logprobs,
                group_sample_logprobs,
            ) = (  # Utilize deferred pythonization results
                prompt_logprobs[sgdx],
                sample_logprobs[sgdx],
            )
        elif any_logprobs_are_requested:
            # NOTE(review): `output.outputs` is asserted empty on entry and
            # gains exactly one entry per loop iteration, so here it holds
            # only `sgdx` entries and `output.outputs[sgdx]` is out of
            # range. This branch presumably never runs (i.e.
            # `skip_sampler_cpu_output` is expected to be True whenever
            # logprobs are requested) — confirm.
            (
                group_prompt_logprobs,
                group_sample_logprobs,
            ) = (
                # profile_run: use already-computed logprobs
                output.outputs[sgdx].prompt_logprobs,
                [sample.logprobs for sample in output.outputs[sgdx].samples])

        seq_ids = seq_group.seq_ids
        next_token_ids = sample_result
        # A single parent; zipping with `[0]` below keeps only the first
        # token id of this group's row.
        parent_ids = [0]
        seq_outputs: List[SequenceOutput]

        if cache is not None:
            # Recycle a cached group output, emptying its previous samples.
            completion_seq_group_output: CompletionSequenceGroupOutput = \
                cache.cached_completion_seq_group_output.get_object()
            completion_seq_group_output.samples.clear()
            seq_outputs = completion_seq_group_output.samples
        else:
            seq_outputs = []

        for tdx, (parent_id,
                  next_token_id) in enumerate(zip(parent_ids, next_token_ids)):
            if cache is not None:
                seq_output: SequenceOutput = cache.cached_seq_output.get_object(
                )
                seq_output.parent_seq_id = seq_ids[parent_id]
                seq_output.output_token = next_token_id

                if any_logprobs_are_requested:
                    seq_output.logprobs = group_sample_logprobs[tdx]
                else:
                    # Reuse the recycled Logprob object, re-keyed under the
                    # new token id with the same dummy values the uncached
                    # branch uses (logprob=inf, no rank, no decoded token).
                    # Assumes the cached SequenceOutput already carries at
                    # least one Logprob entry — TODO confirm against
                    # PythonizationCache.
                    logprobs = next(iter(seq_output.logprobs.values()))
                    seq_output.logprobs.clear()

                    logprobs.logprob = float('inf')
                    logprobs.rank = None
                    logprobs.decoded_token = None

                    seq_output.logprobs[next_token_id] = logprobs

                seq_outputs.append(seq_output)

            else:
                seq_outputs.append(
                    SequenceOutput(seq_ids[parent_id], next_token_id,
                                   (group_sample_logprobs[tdx]
                                    if any_logprobs_are_requested else {
                                        next_token_id:
                                        Logprob(logprob=float('inf'),
                                                rank=None,
                                                decoded_token=None)
                                    })))
        if cache is not None:
            completion_seq_group_output.prompt_logprobs = \
                group_prompt_logprobs if any_logprobs_are_requested else None
            output.outputs.append(completion_seq_group_output)
        else:
            output.outputs.append(
                CompletionSequenceGroupOutput(
                    seq_outputs, (group_prompt_logprobs
                                  if any_logprobs_are_requested else None)))

    assert len(output.outputs) > 0