# multi_step.py
# SPDX-License-Identifier: Apache-2.0

3
import functools
4
from typing import Callable, List, cast
5
6
7
8

from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
9
10
from vllm.engine.output_processor.single_step import (
    single_step_process_prompt_logprob)
11
12
13
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
14
15
16
17
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
                           CompletionSequenceGroupOutput, Sequence,
                           SequenceGroup, SequenceGroupOutput, SequenceOutput,
                           SequenceStatus)
18
from vllm.transformers_utils.detokenizer import Detokenizer
19
from vllm.transformers_utils.tokenizer import AnyTokenizer
20
from vllm.utils import Counter
# Module-level logger for the multi-step output processor defined below.
logger = init_logger(__name__)


class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
    """SequenceGroupOutputProcessor which handles logic related to
    detokenization and stopping conditions. It specializes to "multi-step
    decoding", where vLLM's worker may generate multiple tokens per invocation.
    This is currently mutually exclusive with advanced sampling techniques like
    beam search, which motivates the separation of this logic from the single
    step output processor.

    This class is responsible for things such as correctly appending all new
    token ids to their sequence, detokenizing new token ids, truncating new
    output tokens after an eos token, and correctly handling the case where the
    number of new output tokens per sequence differs in a single batch.
    """

    def __init__(
        self,
        detokenizer: Detokenizer,
        scheduler: List[Scheduler],
        seq_counter: Counter,
        get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
        stop_checker: StopChecker,
    ):
        """Store the collaborators used while post-processing outputs.

        Args:
            detokenizer: Used to incrementally detokenize appended tokens.
                When falsy, detokenization and EOS truncation are skipped.
            scheduler: Stored on the instance; not used by the methods of
                this class.
            seq_counter: Stored on the instance; not used by the methods of
                this class.
            get_tokenizer_for_seq: Returns the tokenizer for a sequence; used
                to look up the EOS token id.
            stop_checker: Evaluates stopping conditions after each appended
                token.
        """
        self.detokenizer = detokenizer
        self.scheduler = scheduler
        self.seq_counter = seq_counter
        self.get_tokenizer_for_seq = get_tokenizer_for_seq
        self.stop_checker = stop_checker

    def process_prompt_logprob(self, seq_group: SequenceGroup,
                               outputs: List[SequenceGroupOutput]) -> None:
        """Process prompt logprobs associated with each step of a multi-step-
        scheduled computation.

        Args:
          seq_group: the outputs are associated with this {class}`SequenceGroup`
          outputs: the {class}`SequenceGroupOutput`s for all scheduler steps
        """
        for output in outputs:
            # Concatenate single-step prompt logprob processing results.
            assert isinstance(output, CompletionSequenceGroupOutput)
            single_step_process_prompt_logprob(self, seq_group, output)

    @staticmethod
    @functools.lru_cache
    def _log_prompt_logprob_unsupported_warning_once():
        """Warn that prompt logprobs are unsupported by multi-step workers.

        The zero-argument call is memoized by ``lru_cache``, so the warning is
        emitted at most once per process.
        """
        # Reminder: Please update docs/features/compatibility_matrix.md
        # if the feature combo becomes valid.
        logger.warning(
            "Prompt logprob is not supported by multi step workers. "
            "(e.g., speculative decode uses multi step workers).")

    def process_outputs(self,
                        sequence_group: SequenceGroup,
                        outputs: List[SequenceGroupOutput],
                        is_async: bool = False) -> None:
        """Append new tokens in the outputs to sequences in the sequence group.

        This only supports sequence groups of size 1. It supports greater than
        one new token per sequence.

        This applies logic like stop condition checking and detokenization.
        It also handles cases where there are tokens emitted after
        the EOS token.

        is_async - Indicates whether this postprocessor runs in
            parallel with the GPU forward pass and is processing
            tokens from the previous step. If this is true, then
            no tokens need to be appended since it is already done
            externally (before the next schedule() call)
        """
        # Sequences can be in RUNNING or FINISHED_ABORTED state
        # once scheduled, as a sequence is moved to FINISHED_ABORTED
        # if a client disconnects from the api server.
        # get_seqs() returns a (possibly empty) collection, so fall back on
        # emptiness; an `is None` check would never trigger the fallback.
        seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING)
        if not seqs:
            seqs = sequence_group.get_seqs(
                status=SequenceStatus.FINISHED_ABORTED)

        # Count, per scheduler step, the steps that produced a valid token
        # (presumably speculative-decoding acceptance stats, per the name).
        for output in outputs:
            if output.samples[0].output_token != VLLM_INVALID_TOKEN_ID:
                sequence_group.metrics.spec_token_acceptance_counts[
                    output.step_index] += 1

        assert seqs, "Expected RUNNING or FINISHED_ABORTED sequences"
        assert len(seqs) == 1, (
            "Beam search not supported in multi-step decoding.")
        seq = seqs[0]
        seq_id = seq.seq_id

        # This method is defined in the more generic
        # SequenceGroupOutputProcessor, but here we assume that the outputs are
        # of a more specific type.
        assert all(
            isinstance(output, CompletionSequenceGroupOutput)
            for output in outputs)
        compl_outputs = cast(List[CompletionSequenceGroupOutput], outputs)
        assert all(seq_id == output.samples[0].parent_seq_id
                   for output in compl_outputs)

        if is_async:
            # Async case: We process tokens one by one. Here, we know the token
            # was already appended, so we only need to do the rest of the
            # postprocessor: Detokenization + stopping logic
            self._process_decode_and_stop(seq, sequence_group.sampling_params)
        else:
            # Standard multi-step case

            # Since there's only one sequence per sequence group,
            # we can take the first sample.
            samples = [output.samples[0] for output in compl_outputs]

            # entries in sample tokens may be invalid (eg. due to spec decode
            # rejecting tokens).
            valid_samples = [
                sample for sample in samples
                if sample.output_token != VLLM_INVALID_TOKEN_ID
            ]

            # When both spec-decode and pre-fill chunking are enabled, we
            # don't have guaranteed samples here (e.g. all -1s).
            if valid_samples:
                self._process_seq_outputs(seq, valid_samples,
                                          sequence_group.sampling_params)

    def _process_decode_and_stop(self, seq: Sequence,
                                 sampling_params: SamplingParams) -> None:
        """Detokenize ``seq`` in place (if enabled), then check stop conditions.

        The number of newly decoded characters (0 when detokenization is off)
        is forwarded to the stop checker.
        """
        new_char_count = 0
        if sampling_params.detokenize and self.detokenizer:
            new_char_count = self.detokenizer.decode_sequence_inplace(
                seq, sampling_params)

        # TODO(sang): Support lora.
        self.stop_checker.maybe_stop_sequence(
            seq,
            new_char_count=new_char_count,
            sampling_params=sampling_params,
        )

    def _process_seq_outputs(self, seq: Sequence,
                             valid_samples: List[SequenceOutput],
                             sampling_params: SamplingParams) -> None:
        """Append the tokens from ``valid_samples`` to ``seq`` one at a time.

        Tokens are first truncated to respect ``max_tokens`` and (unless
        ``ignore_eos``) to end at the first EOS token. After each append the
        sequence is detokenized and stop-checked, and processing ends early
        once the sequence is finished.
        """
        output_token_ids = [sample.output_token for sample in valid_samples]
        output_logprobs = [sample.logprobs for sample in valid_samples]
        output_embeds = [sample.output_embed for sample in valid_samples]

        # Truncate to max_tokens if necessary.
        # NOTE(review): guarded in case max_tokens is unset (None); callers
        # may already fill it in, in which case this check is a no-op.
        max_tokens = sampling_params.max_tokens
        if max_tokens is not None:
            remaining_tokens = max_tokens - (seq.get_output_len() +
                                             len(output_token_ids))
            if remaining_tokens < 0:
                output_token_ids = output_token_ids[:remaining_tokens]

        # Truncate any tokens after EOS. This is required as spec decode
        # generates a fixed number of tokens without evaluating stopping
        # conditions within the block. This can cause an eos token to be
        # unintentionally ignored.
        if not sampling_params.ignore_eos and self.detokenizer:
            eos_token_id = self.get_tokenizer_for_seq(seq).eos_token_id
            # Avoiding .index calls as exception throwing in the happy path
            # is expensive.
            for i, token_id in enumerate(output_token_ids):
                if token_id == eos_token_id:
                    output_token_ids = output_token_ids[:i + 1]
                    break

        is_prefill_sampled_token = seq.data.get_num_uncomputed_tokens() == 0
        # Incrementally append tokens to the sequence, as if we had only one new
        # token. zip() stops at the (possibly truncated) token list, so the
        # logprob/embed lists need no separate truncation.
        for output_token_id, output_logprob, output_embed in zip(
                output_token_ids, output_logprobs, output_embeds):
            seq.append_token_id(
                token_id=output_token_id,
                logprobs=output_logprob,
                token_embed=output_embed,
            )

            if is_prefill_sampled_token:
                is_prefill_sampled_token = False
            else:
                # Update num_computed_tokens iff the sampled token is not from
                # a prefill step.
                seq.data.update_num_computed_tokens(1)

            self._process_decode_and_stop(seq, sampling_params)

            if seq.is_finished():
                break