# SPDX-License-Identifier: Apache-2.0

from typing import Callable, List, Optional, Tuple

from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceStatus
from vllm.transformers_utils.tokenizer import AnyTokenizer


class StopChecker:
    """LLMEngine helper class which separates out the logic involving stop
    checking. This checks things such as: whether the eos token was emitted,
    whether the max_tokens has been consumed, whether a stop string has been
    emitted, or if we have exceeded the max model len.
    """

    def __init__(self, max_model_len: int,
                 get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer]):
        """
        Args:
            max_model_len: Default length limit, compared against
                ``seq.get_len()``; a LoRA request may override it (see
                ``_get_max_model_len``).
            get_tokenizer_for_seq: Callable returning the tokenizer for a
                sequence. Stored as a public attribute; the checks in this
                class do not call it.
        """
        # Do not use it directly, but use `self._get_max_model_len`.
        self._max_model_len = max_model_len
        self.get_tokenizer_for_seq = get_tokenizer_for_seq

    def _get_max_model_len(self, lora_req: Optional[LoRARequest]) -> int:
        """Return the length limit that applies to a sequence.

        A LoRA request with a truthy ``long_lora_max_len`` overrides the
        engine-wide ``max_model_len`` given at construction.
        """
        if lora_req and lora_req.long_lora_max_len:
            return lora_req.long_lora_max_len
        return self._max_model_len

    def maybe_stop_sequence(
        self,
        seq: Sequence,
        new_char_count: int,
        sampling_params: SamplingParams,
        lora_req: Optional[LoRARequest] = None,
    ) -> None:
        """Mark ``seq`` as finished if any stop condition is met.

        Conditions are evaluated in this order; the first match wins:

        1. The EOS token was emitted (unless ``ignore_eos`` is set).
        2. A token in ``sampling_params.stop_token_ids`` was emitted.
        3. A string in ``sampling_params.stop`` appeared in the new text.
        4. ``seq.get_len()`` exceeds the (possibly LoRA-specific) max
           model length.
        5. The number of output tokens reached ``max_tokens``.

        Conditions 1-3 are only checked once at least
        ``sampling_params.min_tokens`` output tokens exist. The length caps
        (4-5) always apply, so ``min_tokens`` cannot push a sequence past
        the model length limit.

        Args:
            seq: Sequence to check. Its ``status``, ``stop_reason`` and
                ``output_text`` may be updated in place.
            new_char_count: Number of chars added to the sequence's output
                text for the newly generated token.
            sampling_params: Sampling parameters of the request.
            lora_req: Optional LoRA request that may override the max model
                length.
        """
        # Stop tokens/strings and EOS are suppressed until min_tokens output
        # tokens exist, but the length caps below must still be enforced.
        if (seq.get_output_len() >= sampling_params.min_tokens
                and self._check_stop_conditions(seq, new_char_count,
                                                sampling_params)):
            return

        # Check if the sequence has reached max_model_len.
        if seq.get_len() > self._get_max_model_len(lora_req):
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
            return

        # Check if the sequence has reached max_tokens. ``>=`` rather than
        # ``==`` so an overshoot is still capped; a None max_tokens disables
        # this check.
        max_tokens = sampling_params.max_tokens
        if max_tokens is not None and seq.get_output_len() >= max_tokens:
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED

    def _check_stop_conditions(
        self,
        seq: Sequence,
        new_char_count: int,
        sampling_params: SamplingParams,
    ) -> bool:
        """Apply the EOS, stop-token and stop-string checks.

        Returns True after marking ``seq`` as FINISHED_STOPPED if one of them
        matched; otherwise returns False and leaves ``seq`` untouched.
        """
        # This assumes a single token produced per step.
        last_token_id = seq.get_last_token_id()

        # Check if the sequence has generated the EOS token.
        if (not sampling_params.ignore_eos
                and last_token_id == seq.eos_token_id):
            # Remove the EOS token's text unless explicitly requested; this
            # prevents unintended exposure of the EOS token.
            self._drop_new_text(seq, new_char_count, sampling_params)
            seq.status = SequenceStatus.FINISHED_STOPPED
            return True

        # Check if a stop token was encountered.
        if last_token_id in (sampling_params.stop_token_ids or ()):
            self._drop_new_text(seq, new_char_count, sampling_params)
            seq.status = SequenceStatus.FINISHED_STOPPED
            seq.stop_reason = last_token_id
            return True

        # Check if any stop strings are matched.
        stop = self.check_stop_strings(
            seq.output_text, new_char_count, sampling_params.stop,
            sampling_params.include_stop_str_in_output)
        if stop is None:
            return False

        stop_str, truncate_to = stop
        if truncate_to != -1:
            seq.output_text = seq.output_text[:truncate_to]
        seq.status = SequenceStatus.FINISHED_STOPPED
        seq.stop_reason = stop_str
        return True

    @staticmethod
    def _drop_new_text(seq: Sequence, new_char_count: int,
                       sampling_params: SamplingParams) -> None:
        """Strip the newest token's text from ``seq.output_text``.

        Nothing is removed when ``new_char_count`` is 0 or when
        ``include_stop_str_in_output`` is set.
        """
        if new_char_count and not sampling_params.include_stop_str_in_output:
            seq.output_text = seq.output_text[:-new_char_count]

    @staticmethod
    def check_stop_strings(
        output_text: str,
        new_char_count: int,
        stop: Optional[List[str]],
        include_in_output: bool,
    ) -> Optional[Tuple[str, int]]:
        """Check if any stop strings are matched and truncate sequence
        output text accordingly.

        Only matches that end inside the last ``new_char_count`` characters
        of ``output_text`` are considered; earlier text is assumed to have
        been searched on previous calls. Stop strings are tried in list
        order and the first one found wins. Empty stop strings are ignored.

        Returns tuple (stop_string, offset) if matched or else None.

        Where stop_string is the matched stop string and offset is the
        length to which output_text should be truncated, or -1 for no
        truncation.
        """
        if not new_char_count or not stop:
            return None

        for stop_str in stop:
            if not stop_str:
                # An empty string would "match" immediately and truncate
                # the whole output.
                continue
            stop_string_len = len(stop_str)
            # Avoid searching already-searched text: a new match must end in
            # the new chars. A negative start index counts from the end.
            stop_index = output_text.find(stop_str,
                                          1 - new_char_count - stop_string_len)
            if stop_index == -1:
                continue

            if include_in_output:
                # Truncate to end of stop string.
                stop_index += stop_string_len
                if stop_index >= len(output_text):
                    # No truncation required.
                    return stop_str, -1

            # Truncate the output text to either the beginning
            # or end of the stop string.
            return stop_str, stop_index
        return None