stop_checker.py 5.11 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import os
4
from typing import Callable, List, Optional, Tuple
5

6
from vllm.lora.request import LoRARequest
7
8
from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceStatus
9
from vllm.transformers_utils.tokenizer import AnyTokenizer
10

11
12
13
14
15
16
17
18
19

class StopChecker:
    """LLMEngine helper class which separates out the logic involving stop
    checking. This checks things such as: whether the eos token was emitted,
    whether the max_tokens has been consumed, whether a stop string has been
    emitted, or if we have exceeded the max model len.
    """

    def __init__(self, max_model_len: int,
                 get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer]):
        """Store the length limit, the tokenizer getter and the env flag.

        Args:
            max_model_len: Default cap on the total sequence length. A LoRA
                request can override it, see `_get_max_model_len`.
            get_tokenizer_for_seq: Callable mapping a sequence to its
                tokenizer. It is only stored here; no method of this class
                calls it.
        """
        # Do not use it directly, but use `self._get_max_model_len`.
        self._max_model_len = max_model_len
        self.get_tokenizer_for_seq = get_tokenizer_for_seq
        # True only when the VLLM_ZERO_OVERHEAD environment variable is
        # exactly '1'. The flag is forwarded unchanged to the Sequence
        # getters used below; its effect is defined by Sequence, not here.
        self.zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'

    def _get_max_model_len(self, lora_req: Optional[LoRARequest]) -> int:
        """Return the length cap that applies to this request.

        A truthy `lora_req.long_lora_max_len` takes precedence; otherwise
        (no LoRA request, or the attribute is None/0) the value given to
        the constructor is used.
        """
        if lora_req and lora_req.long_lora_max_len:
            return lora_req.long_lora_max_len
        return self._max_model_len

    def maybe_stop_sequence(
        self,
        seq: Sequence,
        new_char_count: int,
        sampling_params: SamplingParams,
        lora_req: Optional[LoRARequest] = None,
    ) -> None:
        """Stop the finished sequences.

        The checks run in a fixed order and the first one that fires wins:
        EOS token, stop token id, stop string, max model length, max_tokens.
        When a check fires, `seq.status` is updated in place and, for stop
        tokens and stop strings, `seq.stop_reason` is set to the matched
        token id or string. `seq.output_text` may be truncated so that the
        stop text is not exposed. If nothing fires the sequence is left
        untouched.

        Args:
            seq: The sequence to inspect and, if finished, mark as stopped.
            new_char_count: The number of chars added to the sequence's
                output text for the newly generated token.
            sampling_params: Supplies min_tokens, max_tokens, ignore_eos,
                stop, stop_token_ids and include_stop_str_in_output.
            lora_req: Optional LoRA request used to resolve the max model
                length.
        """
        zero_overhead = self.zero_overhead

        # Check if the minimum number of tokens has been generated yet;
        # skip the stop string/token checks if not
        if seq.get_output_len(zero_overhead) < sampling_params.min_tokens:
            return

        # Read once and reuse for both the EOS and the stop-token check.
        # This assumes a single token produced per step.
        last_token_id = seq.get_last_token_id(zero_overhead)

        # Check if the sequence has generated the EOS token.
        if (not sampling_params.ignore_eos
                and last_token_id == seq.eos_token_id):
            # Remove the last EOS token unless explicitly specified
            # This prevents unintended exposure of the EOS token
            if new_char_count and (
                    not sampling_params.include_stop_str_in_output):
                seq.output_text = seq.output_text[:-new_char_count]
            seq.status = SequenceStatus.FINISHED_STOPPED
            return

        # Check if a stop token was encountered.
        if last_token_id in (sampling_params.stop_token_ids or ()):
            if new_char_count and (
                    not sampling_params.include_stop_str_in_output):
                # Remove last token
                seq.output_text = seq.output_text[:-new_char_count]
            seq.status = SequenceStatus.FINISHED_STOPPED
            seq.stop_reason = last_token_id
            return

        # Check if any stop strings are matched.
        stop = self.check_stop_strings(
            seq.output_text, new_char_count, sampling_params.stop,
            sampling_params.include_stop_str_in_output)
        if stop is not None:
            stop_str, truncate_to = stop
            # -1 means the stop string already ends the text: keep it all.
            if truncate_to != -1:
                seq.output_text = seq.output_text[:truncate_to]
            seq.status = SequenceStatus.FINISHED_STOPPED
            seq.stop_reason = stop_str
            return

        # Check if the sequence has reached max_model_len.
        if seq.get_len(zero_overhead) > self._get_max_model_len(lora_req):
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
            return

        # Check if the sequence has reached max_tokens.
        if seq.get_output_len(zero_overhead) == sampling_params.max_tokens:
            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
            return

    @staticmethod
    def check_stop_strings(
        output_text: str,
        new_char_count: int,
        stop: List[str],
        include_in_output: bool,
    ) -> Optional[Tuple[str, int]]:
        """Check if any stop strings are matched and truncate sequence
        output text accordingly.

        Stop strings are tried in list order and the first one found wins.
        Only the tail of `output_text` that the newly added characters
        could have completed is searched.

        Args:
            output_text: The full output text generated so far.
            new_char_count: Number of characters appended by the latest
                token; nothing is searched when it is 0.
            stop: Candidate stop strings; nothing is searched when empty.
            include_in_output: Whether the matched stop string is kept in
                the output text.

        Returns tuple (stop_string, offset) if matched or else None.

        Where stop_string is the matched stop string and offset is the
        length to which output_text should be truncated, or -1 for no
        truncation.
        """
        if not new_char_count or not stop:
            return None

        for stop_str in stop:
            stop_string_len = len(stop_str)
            # Avoid searching already-searched text. A negative start
            # counts from the end of the text and is clamped to 0 by
            # str.find when it reaches past the beginning.
            stop_index = output_text.find(stop_str,
                                          -new_char_count - stop_string_len)
            if stop_index == -1:
                continue

            if include_in_output:
                # Truncate to end of stop string.
                stop_index += stop_string_len
                if stop_index >= len(output_text):
                    # No truncation required.
                    return stop_str, -1

            # Truncate the output text to either the beginning
            # or end of the stop string.
            return stop_str, stop_index
        return None