multi_step.py 9.28 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import functools
5
from typing import Callable, List, cast
6
7
8
9

from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
10
11
from vllm.engine.output_processor.single_step import (
    single_step_process_prompt_logprob)
12
13
14
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
15
16
17
18
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
                           CompletionSequenceGroupOutput, Sequence,
                           SequenceGroup, SequenceGroupOutput, SequenceOutput,
                           SequenceStatus)
19
from vllm.transformers_utils.detokenizer import Detokenizer
20
from vllm.transformers_utils.tokenizer import AnyTokenizer
21
from vllm.utils import Counter
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42

logger = init_logger(__name__)


class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
    """SequenceGroupOutputProcessor which handles logic related to
    detokenization and stopping conditions. It specializes to "multi-step
    decoding", where vLLM's worker may generate multiple tokens per invocation.
    This is currently mutually exclusive with advanced sampling techniques like
    beam search, which motivates the separation of this logic from the single
    step output processor.

    This class is responsible for things such as correctly appending all new
    token ids to their sequence, detokenizing new token ids, truncating new
    output tokens after an eos token, and correctly handling the case where the
    number of new output tokens per sequence differs in a single batch.
    """

    def __init__(
        self,
        detokenizer: Detokenizer,
        scheduler: List[Scheduler],
        seq_counter: Counter,
        get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
        stop_checker: StopChecker,
    ):
        """Store the collaborators used while processing outputs.

        Args:
            detokenizer: Used to detokenize appended tokens in place. When
                falsy, both detokenization and EOS truncation are skipped.
            scheduler: Scheduler instances, stored on the processor.
            seq_counter: Sequence-id counter, stored on the processor.
            get_tokenizer_for_seq: Returns the tokenizer for a sequence; used
                here to look up the EOS token id.
            stop_checker: Evaluates stop conditions after each new token.
        """
        self.detokenizer = detokenizer
        self.scheduler = scheduler
        self.seq_counter = seq_counter
        self.get_tokenizer_for_seq = get_tokenizer_for_seq
        self.stop_checker = stop_checker

    def process_prompt_logprob(self, seq_group: SequenceGroup,
                               outputs: List[SequenceGroupOutput]) -> None:
        """Process prompt logprobs associated with each step of a multi-step-
        scheduled computation.

        Args:
          seq_group: the outputs are associated with this
              [`SequenceGroup`][vllm.sequence.SequenceGroup]
          outputs: the
              [`SequenceGroupOutput`][vllm.sequence.SequenceGroupOutput]s
              for all scheduler steps
        """
        for output in outputs:
            # Concatenate single-step prompt logprob processing results.
            assert isinstance(output, CompletionSequenceGroupOutput)
            single_step_process_prompt_logprob(self, seq_group, output)

    @staticmethod
    @functools.lru_cache
    def _log_prompt_logprob_unsupported_warning_once():
        # lru_cache on a zero-argument function memoizes the single call, so
        # the warning below is emitted at most once per process.
        # Reminder: Please update docs/features/compatibility_matrix.md
        # If the feature combo become valid
        logger.warning(
            "Prompt logprob is not supported by multi step workers. "
            "(e.g., speculative decode uses multi step workers).")

    def process_outputs(self,
                        sequence_group: SequenceGroup,
                        outputs: List[SequenceGroupOutput],
                        is_async: bool = False) -> None:
        """Append new tokens in the outputs to sequences in the sequence group.

        This only supports sequence groups of size 1. It supports greater than
        one new token per sequence.

        This applies logic like stop condition checking and detokenization.
        It also handles cases where there are tokens emitted after
        the EOS token.

        Args:
            sequence_group: The group whose single sequence receives the
                new tokens.
            outputs: One output per scheduler step; each must be a
                CompletionSequenceGroupOutput whose first sample belongs to
                the group's sequence.
            is_async: Indicates whether this postprocessor runs in
                parallel with the GPU forward pass and is processing
                tokens from the previous step. If this is true, then
                no tokens need to be appended since it is already done
                externally (before the next schedule() call).
        """
        # Sequences can be in RUNNING or FINISHED_ABORTED state
        # once scheduled, as a sequence is moved to FINISHED_ABORTED
        # if a client disconnects from the api server.
        seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING)
        # Test emptiness rather than `is None`: a group with no RUNNING
        # sequences yields an empty result, and the aborted-sequence
        # fallback must still fire in that case.
        if not seqs:
            seqs = sequence_group.get_seqs(
                status=SequenceStatus.FINISHED_ABORTED)

        # Count, per step index, the steps whose first sample is a valid
        # (non-rejected) token.
        for output in outputs:
            if output.samples[0].output_token != VLLM_INVALID_TOKEN_ID:
                sequence_group.metrics.spec_token_acceptance_counts[
                    output.step_index] += 1

        assert seqs, "Expected RUNNING or FINISHED_ABORTED sequences"
        assert len(seqs) == 1, (
            "Beam search not supported in multi-step decoding.")
        seq = seqs[0]
        seq_id = seq.seq_id

        # This method is defined in the more generic
        # SequenceGroupOutputProcessor, but here we assume that the outputs are
        # of a more specific type.
        assert all(
            isinstance(output, CompletionSequenceGroupOutput)
            for output in outputs)
        compl_outputs = cast(List[CompletionSequenceGroupOutput], outputs)
        assert all(seq_id == output.samples[0].parent_seq_id
                   for output in compl_outputs)

        if is_async:
            # Async case: We process tokens one by one. Here, we know the token
            # was already appended, so we only need to do the rest of the
            # postprocessor: Detokenization + stopping logic
            self._process_decode_and_stop(seq, sequence_group.sampling_params)
            return

        # Standard multi-step case.
        # Since there's only one sequence per sequence group,
        # we can take the first sample.
        samples = [output.samples[0] for output in compl_outputs]

        # Entries in sample tokens may be invalid (eg. due to spec decode
        # rejecting tokens).
        valid_samples = [
            sample for sample in samples
            if sample.output_token != VLLM_INVALID_TOKEN_ID
        ]

        # When both spec-decode and pre-fill chunking are enabled, we
        # don't have guaranteed samples here (e.g. all -1s).
        if valid_samples:
            self._process_seq_outputs(seq, valid_samples,
                                      sequence_group.sampling_params)

    def _process_decode_and_stop(self, seq: Sequence,
                                 sampling_params: SamplingParams) -> None:
        """Detokenize ``seq`` in place (if enabled) and run stop checks.

        Args:
            seq: The sequence whose latest token was just appended.
            sampling_params: Request sampling parameters; ``detokenize``
                controls whether detokenization runs.
        """
        new_char_count = 0
        if sampling_params.detokenize and self.detokenizer:
            new_char_count = self.detokenizer.decode_sequence_inplace(
                seq, sampling_params)

        # TODO(sang): Support lora.
        self.stop_checker.maybe_stop_sequence(
            seq,
            new_char_count=new_char_count,
            sampling_params=sampling_params,
        )

    def _process_seq_outputs(self, seq: Sequence,
                             valid_samples: List[SequenceOutput],
                             sampling_params: SamplingParams) -> None:
        """Append valid sampled tokens to ``seq`` one at a time.

        The new tokens are first truncated to respect ``max_tokens`` and
        (unless ``ignore_eos`` is set or no detokenizer is configured) cut
        right after the first EOS token. Each surviving token is then
        appended and followed by detokenization and stop checking, stopping
        early once the sequence finishes.

        Args:
            seq: The sequence to extend.
            valid_samples: Samples whose tokens are not VLLM_INVALID_TOKEN_ID.
            sampling_params: Request sampling parameters.
        """
        output_token_ids = [sample.output_token for sample in valid_samples]
        output_logprobs = [sample.logprobs for sample in valid_samples]
        output_embeds = [sample.output_embed for sample in valid_samples]

        # Truncate to max_tokens if necessary. Guard against an unset (None)
        # max_tokens: `None - int` would raise TypeError here. Presumably it
        # is resolved before reaching this point; the stop checker still
        # runs after every appended token either way.
        max_tokens = sampling_params.max_tokens
        if max_tokens is not None:
            remaining_tokens = max_tokens - (seq.get_output_len() +
                                             len(output_token_ids))
            if remaining_tokens < 0:
                output_token_ids = output_token_ids[:remaining_tokens]

        # Truncate any tokens after EOS. This is required as spec decode
        # generates a fixed number of tokens without evaluating stopping
        # conditions within the block. This can cause an eos token to be
        # unintentionally ignored.
        if not sampling_params.ignore_eos and self.detokenizer:
            eos_token_id = self.get_tokenizer_for_seq(seq).eos_token_id
            # Avoiding .index calls as exception throwing in the happy path
            # is expensive.
            for i, token_id in enumerate(output_token_ids):
                if token_id == eos_token_id:
                    output_token_ids = output_token_ids[:i + 1]
                    break

        is_prefill_sampled_token = seq.data.get_num_uncomputed_tokens() == 0
        # Incrementally append tokens to the sequence, as if we had only one new
        # token. zip() stops at the (possibly truncated) token list, so the
        # logprob/embed lists need no separate truncation.
        for output_token_id, output_logprob, output_embed in zip(
                output_token_ids, output_logprobs, output_embeds):
            seq.append_token_id(
                token_id=output_token_id,
                logprobs=output_logprob,
                token_embed=output_embed,
            )

            if is_prefill_sampled_token:
                is_prefill_sampled_token = False
            else:
                # Update num_computed_tokens iff the sampled token is not from
                # a prefill step.
                seq.data.update_num_computed_tokens(1)

            self._process_decode_and_stop(seq, sampling_params)

            if seq.is_finished():
                break