detokenizer_utils.py 7.81 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4

5
from vllm.tokenizers import TokenizerLike
6
7


8
def _replace_none_with_empty(tokens: list[str | None]):
9
10
11
12
13
14
    for i, token in enumerate(tokens):
        if token is None:
            tokens[i] = ""


def _convert_tokens_to_string_with_added_encoders(
15
    tokenizer: TokenizerLike,
16
    output_tokens: list[str],
17
18
    skip_special_tokens: bool,
    spaces_between_special_tokens: bool,
19
    mode: str,
20
21
22
23
24
25
) -> str:
    # Adapted from
    # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
    # NOTE(woosuk): The following code is slow because it runs a for loop over
    # the output_tokens. In Python, running a for loop over a list can be slow
    # even when the loop body is very simple.
26
27
28
    # Performance improvements: avoid repeated attribute and function lookups;
    # localize frequently used objects;

29
30
    sub_texts: list[str] = []
    current_sub_text: list[str] = []
31
32
    convert_tokens_to_string = tokenizer.convert_tokens_to_string
    added_vocab_set = set(tokenizer.get_added_vocab())
33
    if mode != "cpm":
34
35
36
        all_special_tokens = (
            set(tokenizer.all_special_tokens) if skip_special_tokens else ()
        )
37
38
    else:
        all_special_tokens = tokenizer._special_token_set
39

40
    for token in output_tokens:
41
42
        # Use precomputed set for skip-special check
        if token in all_special_tokens:
43
            continue
44
        if token in added_vocab_set:
45
            if current_sub_text:
46
47
                sub_texts.append(convert_tokens_to_string(current_sub_text))
                current_sub_text.clear()
48
49
50
51
            sub_texts.append(token)
        else:
            current_sub_text.append(token)
    if current_sub_text:
52
        if mode != "cpm":
53
            sub_texts.append(convert_tokens_to_string(current_sub_text))
54
        else:
55
            sub_texts = tokenizer.decode(current_sub_text)
56
57
    if spaces_between_special_tokens:
        return " ".join(sub_texts)
58
    return "".join(sub_texts)
59
60
61
62
63
64
65
66


# Number of trailing tokens used as context when incremental detokenization
# starts: `convert_prompt_ids_to_tokens` uses it to slice the tail of the
# prompt ids and to place the initial prefix offset.
# 5 is an arbitrary value that should work for all
# tokenizers (bigger = more conservative).
INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5


def convert_prompt_ids_to_tokens(
    tokenizer: TokenizerLike,
    prompt_ids: list[int],
    skip_special_tokens: bool = False,
) -> tuple[list[str], int, int]:
    """Prepare the token tail and offsets that seed incremental detokenization.

    Only the last few prompt ids are converted to token strings; the rest of
    the prompt is not needed for incremental detokenization.

    Args:
        tokenizer: The tokenizer to use.
        prompt_ids: The prompt token ids.
        skip_special_tokens: Forwarded to ``tokenizer.convert_ids_to_tokens``.

    Returns:
        A ``(tokens, prefix_offset, read_offset)`` tuple, where
        ``read_offset`` is the number of returned tokens and
        ``prefix_offset`` lies ``INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET``
        tokens before it (clamped at 0).
    """
    # Take two ids more than the offset in case the tail contains special
    # tokens that get dropped by the conversion.
    tail_len = INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET + 2
    tail_tokens = tokenizer.convert_ids_to_tokens(
        prompt_ids[-tail_len:],
        skip_special_tokens=skip_special_tokens,
    )
    # Out-of-vocab prompt ids can come back as None; normalise them to "".
    _replace_none_with_empty(tail_tokens)  # type: ignore[arg-type]

    read_offset = len(tail_tokens)
    prefix_offset = max(read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0)
    return tail_tokens, prefix_offset, read_offset


90
def convert_ids_list_to_tokens(
    tokenizer: TokenizerLike,
    token_ids: list[int],
) -> list[str]:
    """Decode each token id on its own and collect the resulting strings.

    Args:
      tokenizer: tokenizer whose ``decode`` is called once per id
      token_ids: ids to decode (Python list form)

    Returns:
      One string per input id, in input order; a ``None`` decode result
      is replaced by ``""``
    """
    # decode() is called with its default skip_special_tokens.
    decoded_each = (tokenizer.decode([token_id]) for token_id in token_ids)
    return ["" if piece is None else piece for piece in decoded_each]


114
115
116
117
# Based on
# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
# under Apache 2.0 license
def detokenize_incrementally(
118
    tokenizer: TokenizerLike,
119
    all_input_ids: list[int],
120
    prev_tokens: list[str] | None,
121
122
123
124
    prefix_offset: int,
    read_offset: int,
    skip_special_tokens: bool = False,
    spaces_between_special_tokens: bool = True,
125
    mode: str = "cpm",
126
) -> tuple[list[str], str, int, int]:
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
    """Detokenizes the input ids incrementally and returns the new tokens
    and the new text.

    If `prev_tokens` is None, this function will convert the input ids to
    tokens and return the tokens and the new text. Otherwise, it will return the
    new tokens and the new text.

    This function will also return the new prefix offset and the new read
    offset to be used in the next iteration.

    The offsets are necessary to defeat cleanup algorithms in the decode which
    decide to add a space or not depending on the surrounding ids.

    Args:
        tokenizer: The tokenizer to use.
        all_input_ids: The input ids. The last id is the new token id.
        prev_tokens: The previous tokens. If None, this function will convert
            the input ids to tokens and return the tokens and the new text.
        prefix_offset: The prefix offset.
        read_offset: The read offset.
        skip_special_tokens: Whether to skip special tokens.
        spaces_between_special_tokens: Whether to add spaces between special
            tokens.
    """
    new_token_id = all_input_ids[-1]
    # This is the first iteration for this sequence
    is_first_iter = prev_tokens is None
    if is_first_iter:
155
156
157
        (prev_tokens, prefix_offset, read_offset) = convert_prompt_ids_to_tokens(
            tokenizer, all_input_ids[:-1], skip_special_tokens=skip_special_tokens
        )
158
159
160
    assert prev_tokens is not None

    # If the new token id is out of bounds, return an empty string.
161
162
163
164
165
    if mode == "cpm":
        vocab_size = tokenizer.vocab_size
    else:
        vocab_size = len(tokenizer)
    if 0 <= new_token_id < vocab_size:
166
167
        # Put new_token_id in a list so skip_special_tokens is respected
        new_tokens = tokenizer.convert_ids_to_tokens(
168
169
            [new_token_id], skip_special_tokens=skip_special_tokens
        )
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
        if isinstance(new_tokens, str):
            new_tokens = [new_tokens]
    else:
        new_tokens = [""]
    output_tokens = prev_tokens + new_tokens

    # If this is the first iteration, return all tokens.
    if is_first_iter:
        new_tokens = output_tokens

    # The prefix text is necessary only to defeat cleanup algorithms in
    # the decode which decide to add a space or not depending on the
    # surrounding ids.
    if tokenizer.is_fast or not tokenizer.get_added_vocab():
        prefix_text = tokenizer.convert_tokens_to_string(
185
186
187
            output_tokens[prefix_offset:read_offset]
        )
        new_text = tokenizer.convert_tokens_to_string(output_tokens[prefix_offset:])
188
189
190
191
192
193
    else:
        prefix_text = _convert_tokens_to_string_with_added_encoders(
            tokenizer,
            output_tokens[prefix_offset:read_offset],
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens,
194
            mode=mode,
195
196
197
198
199
200
        )
        new_text = _convert_tokens_to_string_with_added_encoders(
            tokenizer,
            output_tokens[prefix_offset:],
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens,
201
            mode=mode,
202
203
204
205
206
207
208
209
210
        )

    if len(new_text) <= len(prefix_text) or new_text.endswith("�"):
        # utf-8 char at the end means it's a potential unfinished byte sequence
        # from byte fallback tokenization.
        # If it's in the middle, it's probably a real invalid id generated
        # by the model
        return new_tokens, "", prefix_offset, read_offset

211
    new_text = new_text[len(prefix_text) :]
212
    return new_tokens, new_text, read_offset, len(output_tokens)