stop_checker.py 7.7 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from typing import Callable, List, Optional, Tuple
4

5
from vllm.lora.request import LoRARequest
6
7
from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceStatus
8
from vllm.transformers_utils.tokenizer import AnyTokenizer
guanyu1's avatar
guanyu1 committed
9
import  os
10
11
12
13
14
15
16
17
18

class StopChecker:
    """LLMEngine helper class which separates out the logic involving stop
    checking. This checks things such as: whether the eos token was emitted,
    whether the max_tokens has been consumed, whether a stop string has been
    emitted, or if we have exceeded the max model len.

    If the environment variable ``VLLM_ZERO_OVERHEAD`` equals ``"1"`` when
    the checker is constructed, sequence lengths and the last token id are
    read through the ``zero_overhead_*`` accessors of ``Sequence`` instead
    of the regular ``get_*`` accessors. The stop rules themselves are the
    same in both modes.
    """

    def __init__(self, max_model_len: int,
                 get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer]):
        # Do not use it directly, but use `self._get_max_model_len`.
        self._max_model_len = max_model_len
        self.get_tokenizer_for_seq = get_tokenizer_for_seq
        # Read once at construction time; changing the environment variable
        # later does not affect an existing StopChecker.
        self.zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'

    def _get_max_model_len(self, lora_req: Optional[LoRARequest]):
        if lora_req and lora_req.long_lora_max_len:
            return lora_req.long_lora_max_len
        else:
            return self._max_model_len

    def _seq_accessors(
        self, seq: Sequence
    ) -> Tuple[Callable[[], int], Callable[[], int], Callable[[], int]]:
        """Return the (output_len, last_token_id, total_len) getters of
        ``seq`` that match the configured mode (zero-overhead or regular)."""
        if self.zero_overhead:
            return (seq.zero_overhead_get_output_len,
                    seq.zero_overhead_get_last_token_id,
                    seq.zero_overhead_get_len)
        return seq.get_output_len, seq.get_last_token_id, seq.get_len

    @staticmethod
    def _strip_last_token_text(seq: Sequence, new_char_count: int,
                               sampling_params: SamplingParams) -> None:
        """Remove the newest token's characters from ``seq.output_text``
        unless the request asked to keep stop text in the output."""
        if new_char_count and (
                not sampling_params.include_stop_str_in_output):
            seq.output_text = seq.output_text[:-new_char_count]

    def maybe_stop_sequence(
        self,
        seq: Sequence,
        new_char_count: int,
        sampling_params: SamplingParams,
        lora_req: Optional[LoRARequest] = None,
    ) -> None:
        """Stop the finished sequences.

        The checks run in this order, and the first one that matches wins:
        min_tokens (nothing can stop the sequence before it is reached),
        EOS token, stop token ids, stop strings, max model length, and
        max_tokens. On a match, ``seq.status`` is set to a finished status.
        ``seq.stop_reason`` is also set for stop tokens and stop strings.
        ``seq.output_text`` may be truncated.

        Args:
            seq: The sequence to check. It is mutated in place.
            new_char_count: The number of chars added to the sequence's
                output text for the newly generated token.
            sampling_params: Supplies the stop criteria.
            lora_req: Optional LoRA request. When its ``long_lora_max_len``
                is set, it replaces the engine's max model length.
        """
        get_output_len, get_last_token_id, get_len = self._seq_accessors(seq)

        # Check if the minimum number of tokens has been generated yet;
        # skip the stop string/token checks if not.
        if get_output_len() < sampling_params.min_tokens:
            return

        # NOTE(review): the original author noted that the new_char_count
        # handling has not yet been adapted for zero-overhead mode. It is
        # applied identically in both modes.

        # Check if the sequence has generated the EOS token.
        if ((not sampling_params.ignore_eos)
                and get_last_token_id() == seq.eos_token_id):
            # Remove the last EOS token unless explicitly specified.
            # This prevents unintended exposure of the EOS token.
            self._strip_last_token_text(seq, new_char_count, sampling_params)
            seq.status = SequenceStatus.FINISHED_STOPPED
            return

        # Check if a stop token was encountered.
        # This assumes a single token produced per step.
        last_token_id = get_last_token_id()
        if last_token_id in (sampling_params.stop_token_ids or ()):
            self._strip_last_token_text(seq, new_char_count, sampling_params)
            seq.status = SequenceStatus.FINISHED_STOPPED
            seq.stop_reason = last_token_id
            return

        # Check if any stop strings are matched.
        stop = self.check_stop_strings(
            seq.output_text, new_char_count, sampling_params.stop,
            sampling_params.include_stop_str_in_output)
        if stop is not None:
            stop_str, truncate_to = stop
            if truncate_to != -1:
                seq.output_text = seq.output_text[:truncate_to]
            seq.status = SequenceStatus.FINISHED_STOPPED
            seq.stop_reason = stop_str
            return

        # Check if the sequence has reached max_model_len.
        if get_len() > self._get_max_model_len(lora_req):
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
            return

        # Check if the sequence has reached max_tokens.
        # - `>=` (not `==`) so a sequence that overshoots is still capped.
        # - The None guard keeps the tolerance the plain `==` comparison
        #   had for an unset max_tokens; a bare `>=` would raise TypeError.
        max_tokens = sampling_params.max_tokens
        if max_tokens is not None and get_output_len() >= max_tokens:
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED

    @staticmethod
    def check_stop_strings(
        output_text: str,
        new_char_count: int,
        stop: List[str],
        include_in_output: bool,
    ) -> Optional[Tuple[str, int]]:
        """Check if any stop strings are matched and truncate sequence
        output text accordingly.

        Each stop string is searched for only in the tail of
        ``output_text``, namely the last
        ``new_char_count + len(stop_str)`` characters. That tail is wide
        enough to catch a match that overlaps the newly added characters,
        without re-scanning text already examined on earlier steps. Stop
        strings are tried in the given order, and the first one found wins.

        Args:
            output_text: The full output text so far.
            new_char_count: The number of characters appended this step.
                If it is 0, nothing is searched.
            stop: The stop strings to look for. An empty (falsy) value
                disables the check.
            include_in_output: If true, keep the matched stop string in
                the output instead of cutting it off.

        Returns:
            ``None`` if no stop string matched. Otherwise a tuple
            ``(stop_string, offset)``, where ``stop_string`` is the matched
            stop string and ``offset`` is the length to which
            ``output_text`` should be truncated, or -1 for no truncation.
        """
        if not new_char_count or not stop:
            return None

        for stop_str in stop:
            stop_string_len = len(stop_str)
            # Avoid searching already-searched text. A negative start is
            # counted from the end of the string (slice semantics), so it
            # clamps to 0 for short texts.
            stop_index = output_text.find(stop_str,
                                          -new_char_count - stop_string_len)
            if stop_index == -1:
                continue

            if include_in_output:
                # Truncate to end of stop string.
                stop_index += stop_string_len
                if stop_index >= len(output_text):
                    # No truncation required.
                    return stop_str, -1

            # Truncate the output text to either the beginning
            # or end of the stop string.
            return stop_str, stop_index
        return None