multi_step.py 7.57 KB
Newer Older
1
import functools
2
from typing import Callable, List
3
4
5
6

from vllm.core.scheduler import Scheduler
from vllm.engine.output_processor.interfaces import (
    SequenceGroupOutputProcessor)
7
8
from vllm.engine.output_processor.single_step import (
    single_step_process_prompt_logprob)
9
10
11
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
12
13
from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Sequence, SequenceGroup,
                           SequenceGroupOutput, SequenceOutput, SequenceStatus)
14
from vllm.transformers_utils.detokenizer import Detokenizer
15
from vllm.transformers_utils.tokenizer import AnyTokenizer
16
from vllm.utils import Counter
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37

# Module-level logger; used by
# MultiStepOutputProcessor._log_prompt_logprob_unsupported_warning_once.
logger = init_logger(__name__)


class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
    """SequenceGroupOutputProcessor which handles logic related to
    detokenization and stopping conditions. It specializes to "multi-step
    decoding", where vLLM's worker may generate multiple tokens per invocation.
    This is currently mutually exclusive with advanced sampling techniques like
    beam search, which motivates the separation of this logic from the single
    step output processor.

    This class is responsible for things such as correctly appending all new
    token ids to their sequence, detokenizing new token ids, truncating new
    output tokens after an eos token, and correctly handling the case where the
    number of new output tokens per sequence differs in a single batch.
    """

    def __init__(
        self,
        detokenizer: Detokenizer,
        scheduler: List[Scheduler],
        seq_counter: Counter,
        get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
        stop_checker: StopChecker,
    ):
        """Store the collaborators used while post-processing outputs.

        Args:
          detokenizer: decodes new tokens in place (see
            ``_process_decode_and_stop``) when ``sampling_params.detokenize``
            is set.
          scheduler: stored on ``self.scheduler``.
          seq_counter: stored on ``self.seq_counter``.
          get_tokenizer_for_seq: maps a sequence to its tokenizer; this class
            reads only the tokenizer's ``eos_token_id``.
          stop_checker: evaluates stop conditions after each appended token.
        """
        self.detokenizer = detokenizer
        self.scheduler = scheduler
        self.seq_counter = seq_counter
        self.get_tokenizer_for_seq = get_tokenizer_for_seq
        self.stop_checker = stop_checker

    def process_prompt_logprob(self, seq_group: SequenceGroup,
                               outputs: List[SequenceGroupOutput]) -> None:
        """Process prompt logprobs associated with each step of a multi-step-
        scheduled computation.

        Args:
          seq_group: the outputs are associated with this :class:`SequenceGroup`
          outputs: the :class:`SequenceGroupOutput`s for all scheduler steps
        """
        for output in outputs:
            # Concatenate single-step prompt logprob processing results.
            single_step_process_prompt_logprob(self, seq_group, output)

    @staticmethod
    @functools.lru_cache()
    def _log_prompt_logprob_unsupported_warning_once() -> None:
        """Emit the "prompt logprob unsupported" warning.

        The zero-argument ``lru_cache`` memoizes the single call, so the
        warning is logged at most once per process.
        """
        logger.warning(
            "Prompt logprob is not supported by multi step workers. "
            "(e.g., speculative decode uses multi step workers).")

    def process_outputs(self,
                        sequence_group: SequenceGroup,
                        outputs: List[SequenceGroupOutput],
                        is_async: bool = False) -> None:
        """Append new tokens in the outputs to sequences in the sequence group.

        This only supports sequence groups of size 1. It supports greater than
        one new token per sequence.

        This applies logic like stop condition checking and detokenization.
        It also handles cases where there are tokens emitted after
        the EOS token.

        is_async - Indicates whether this postprocessor runs in
            parallel with the GPU forward pass and is processing
            tokens from the previous step. If this is true, then
            no tokens need to be appended since it is already done
            externally (before the next schedule() call)
        """
        # Sequences can be in RUNNING or FINISHED_ABORTED state
        # once scheduled, as a sequence is moved to FINISHED_ABORTED
        # if a client disconnects from the api server.
        seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING)
        # Test for emptiness rather than ``is None``: the result is used as a
        # list below, so "no RUNNING sequences" presumably arrives as an
        # empty list, which an ``is None`` check would never catch.
        if not seqs:
            seqs = sequence_group.get_seqs(
                status=SequenceStatus.FINISHED_ABORTED)

        assert seqs, "Expected RUNNING or FINISHED_ABORTED sequences"
        assert len(seqs) == 1, (
            "Beam search not supported in multi-step decoding.")
        seq = seqs[0]

        if is_async:
            # Async case: We process tokens one by one. Here, we know the token
            # was already appended, so we only need to do the rest of the
            # postprocessor: Detokenization + stopping logic
            self._process_decode_and_stop(seq, sequence_group.sampling_params)
        else:
            # Standard multi-step case

            # Since there's only one sequence per sequence group,
            # we can take the first sample.
            samples = [output.samples[0] for output in outputs]

            # entries in sample tokens may be invalid (eg. due to spec decode
            # rejecting tokens).
            valid_samples = [
                sample for sample in samples
                if sample.output_token != VLLM_INVALID_TOKEN_ID
            ]
            assert valid_samples

            self._process_seq_outputs(seq, valid_samples,
                                      sequence_group.sampling_params)

    def _process_decode_and_stop(self, seq: Sequence,
                                 sampling_params: SamplingParams) -> None:
        """Detokenize ``seq`` in place (if enabled) and run stop checks."""
        new_char_count = 0
        if sampling_params.detokenize:
            new_char_count = self.detokenizer.decode_sequence_inplace(
                seq, sampling_params)

        # TODO(sang): Support lora.
        self.stop_checker.maybe_stop_sequence(
            seq,
            new_char_count=new_char_count,
            sampling_params=sampling_params,
        )

    def _process_seq_outputs(self, seq: Sequence,
                             valid_samples: List[SequenceOutput],
                             sampling_params: SamplingParams) -> None:
        """Append the sampled tokens to ``seq`` one at a time.

        Samples are first truncated to the remaining ``max_tokens`` budget
        and then to the first EOS token (unless ``ignore_eos``). Each kept
        token is appended together with its own logprobs, followed by
        detokenization and stop checking; appending ends early once the
        sequence is finished.
        """
        # Truncate to max_tokens if necessary. Truncating the samples
        # themselves keeps each token id paired with its logprobs.
        # Defensive: a None max_tokens is treated as "no limit".
        if sampling_params.max_tokens is not None:
            budget = sampling_params.max_tokens - seq.get_output_len()
            valid_samples = valid_samples[:max(budget, 0)]

        # Truncate any tokens after EOS. This is required as spec decode
        # generates a fixed number of tokens without evaluating stopping
        # conditions within the block. This can cause an eos token to be
        # unintentionally ignored.
        if not sampling_params.ignore_eos:
            eos_token_id = self.get_tokenizer_for_seq(seq).eos_token_id
            # Avoiding .index calls as exception throwing in the happy path
            # is expensive.
            for idx, sample in enumerate(valid_samples):
                if sample.output_token == eos_token_id:
                    valid_samples = valid_samples[:idx + 1]
                    break

        # Incrementally append tokens to the sequence, as if we had only one new
        # token.
        for sample in valid_samples:
            seq.append_token_id(
                token_id=sample.output_token,
                logprobs=sample.logprobs,
            )

            self._process_decode_and_stop(seq, sampling_params)

            if seq.is_finished():
                break