# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import List

from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
from vllm.sequence import (CompletionSequenceGroupOutput, SequenceGroup,
                           SequenceGroupOutput)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter

# Module-level logger, named after this module via ``__name__``.
logger = init_logger(__name__)


def single_step_process_prompt_logprob(
        sg_output_proc: SequenceGroupOutputProcessor, seq_group: SequenceGroup,
        output: CompletionSequenceGroupOutput) -> None:
    """Process prompt logprobs associated with the
    [`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput] for a given step.

    Do nothing if the output has no prompt logprobs.

    Account for the fact that transformers do not compute first-token logprobs:
    on the first prefill chunk a ``None`` placeholder is prepended so that the
    accumulated list lines up one-to-one with the prompt tokens.

    The step's prompt logprobs are appended to ``seq_group.prompt_logprobs``
    (which is (re)initialised to an empty list on the first chunk). When
    detokenization is enabled and the processor has a detokenizer, they are
    first passed to its ``decode_prompt_logprobs_inplace``.

    Args:
      sg_output_proc:
          [`SequenceGroupOutputProcessor`][vllm.engine.output_processor.interfaces.SequenceGroupOutputProcessor]
          instance; must expose a ``detokenizer`` attribute (may be falsy)
      seq_group: the output is associated with this
          [`SequenceGroup`][vllm.sequence.SequenceGroup]
      output: the [`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput]
          for a single scheduler step
    """
    prompt_logprobs = output.prompt_logprobs
    if prompt_logprobs is None:
        return

    # If this is the first (or only) "chunk" of the prefill, we need
    # to prepend None to the list of prompt logprobs. The reason for this
    # is that for N prompt tokens, the Sampler will generate N-1 total
    # prompt logprobs during prefill since the token at idx 0 will not
    # have a logprob associated with it.
    if not seq_group.prompt_logprobs:
        prompt_logprobs = [None] + prompt_logprobs
        seq_group.prompt_logprobs = []

    assert hasattr(sg_output_proc, 'detokenizer')
    if (seq_group.sampling_params.detokenize
            and sg_output_proc.detokenizer):
        # position_offset is the number of prompt logprobs accumulated by
        # earlier chunks (0 on the first chunk), so later chunks are
        # decoded relative to where they sit in the full prompt.
        sg_output_proc.detokenizer.decode_prompt_logprobs_inplace(
            seq_group,
            prompt_logprobs,
            position_offset=len(seq_group.prompt_logprobs))

    seq_group.prompt_logprobs.extend(prompt_logprobs)


class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
    """SequenceGroupOutputProcessor which handles "output processing" logic,
    which happens after the model returns generated token ids and before
    scheduling of the next batch. Output processing logic includes
    detokenization, and determining if a sequence is finished (e.g. via max len
    or eos token).

    The SingleStepOutputProcessor is specialized to the case where the model
    emits at most a single token per invocation, which precludes configurations
    such as speculative decoding or multi-step decoding. Each step updates only
    the sequence group's first sequence, using the first sample of the step's
    output.
    """

    def __init__(self, scheduler_config: SchedulerConfig,
                 detokenizer: Detokenizer, scheduler: List[Scheduler],
                 seq_counter: Counter, stop_checker: StopChecker):
        self.scheduler_config = scheduler_config
        self.detokenizer = detokenizer
        # A list of schedulers; finished sequences are freed from each one.
        self.scheduler = scheduler
        self.seq_counter = seq_counter
        self.stop_checker = stop_checker

    def process_outputs(self, sequence_group: SequenceGroup,
                        outputs: List[SequenceGroupOutput],
                        is_async: bool) -> None:
        """Apply one step's output to the sequence group.

        Appends the new token to the group's first sequence (unless
        ``is_async``), invokes the detokenizer to detokenize it, and marks
        the sequence as finished if it meets a stop condition. Finished
        sequences are freed from every scheduler.

        Args:
          sequence_group: the group whose output is being processed.
          outputs: exactly one ``SequenceGroupOutput`` for this step.
          is_async: indicates whether this postprocessor runs in
              parallel with the GPU forward pass and is processing
              tokens from the previous step. If this is true, then
              no tokens need to be appended since it is already done
              externally (before the next schedule() call).
        """
        assert (len(outputs) == 1
                ), f"{type(self)} does not support multiple outputs per step"
        self._process_sequence_group_outputs(sequence_group, outputs[0],
                                             is_async)

    def process_prompt_logprob(self, seq_group: SequenceGroup,
                               outputs: List[SequenceGroupOutput]) -> None:
        """Process prompt logprobs associated with one step of a single-step-
        scheduled computation.

        Args:
          seq_group: the output is associated with this
              [`SequenceGroup`][vllm.sequence.SequenceGroup]
          outputs: the
              [`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput]
              for a single scheduler step; must contain exactly one
              ``CompletionSequenceGroupOutput``
        """
        assert len(outputs) == 1, "Single step should only have 1 output."
        output = outputs[0]
        assert isinstance(output, CompletionSequenceGroupOutput)
        single_step_process_prompt_logprob(self, seq_group, output)

    def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
                                        outputs: SequenceGroupOutput,
                                        is_async: bool) -> None:
        """Apply a single step's output to ``seq_group.first_seq``.

        Args:
          seq_group: the sequence group being updated.
          outputs: the one ``SequenceGroupOutput`` for this step; only its
              first sample is used.
          is_async: when True the token has already been appended
              externally, so it is not appended here.
        """
        sampling_params = seq_group.sampling_params

        sample = outputs.samples[0]
        seq = seq_group.first_seq
        # In async mode the token was appended before the next schedule()
        # call, so it must not be appended a second time.
        if not is_async:
            seq.append_token_id(sample.output_token, sample.logprobs,
                                sample.output_embed)

        if sampling_params.detokenize and self.detokenizer:
            new_char_count = self.detokenizer.decode_sequence_inplace(
                seq, sampling_params)
        else:
            new_char_count = 0

        self.stop_checker.maybe_stop_sequence(
            seq,
            new_char_count,
            sampling_params,
            lora_req=seq_group.lora_request,
        )

        # Hand a finished sequence to free_seq on every scheduler.
        if seq.is_finished():
            for scheduler in self.scheduler:
                scheduler.free_seq(seq)