tokenizer.py 10.8 KB
Newer Older
1
from typing import List, Optional, Tuple, Union
2

3
from transformers import (AutoTokenizer, PreTrainedTokenizer,
4
5
                          PreTrainedTokenizerFast)

Woosuk Kwon's avatar
Woosuk Kwon committed
6
from vllm.logger import init_logger
7
from vllm.lora.request import LoRARequest
8
from vllm.transformers_utils.tokenizers import *
9
from vllm.utils import make_async
10
11
12

logger = init_logger(__name__)

13

14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
def get_cached_tokenizer(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Get tokenizer with cached properties.

    This will patch the tokenizer object in place.

    By default, transformers will recompute multiple tokenizer properties
    each time they are called, leading to a significant slowdown. This
    function caches these properties for faster access.

    The cached values are snapshots taken when this function is called;
    changes made to the tokenizer afterwards are not reflected by the
    cached properties or by `len()`.

    Args:
        tokenizer: The tokenizer to patch.

    Returns:
        The same tokenizer object, with its class swapped for a dynamically
        created subclass that serves `all_special_ids`,
        `all_special_tokens`, `all_special_tokens_extended` and `len()`
        from the snapshots. Note that `all_special_ids` and
        `all_special_tokens` are served as sets.
    """
    # Snapshot every expensive property once, up front. Sets are used for
    # the two collections that callers test membership against.
    tokenizer_all_special_ids = set(tokenizer.all_special_ids)
    tokenizer_all_special_tokens_extended = (
        tokenizer.all_special_tokens_extended)
    tokenizer_all_special_tokens = set(tokenizer.all_special_tokens)
    tokenizer_len = len(tokenizer)

    class CachedTokenizer(tokenizer.__class__):

        @property
        def all_special_ids(self):
            return tokenizer_all_special_ids

        @property
        def all_special_tokens(self):
            return tokenizer_all_special_tokens

        @property
        def all_special_tokens_extended(self):
            return tokenizer_all_special_tokens_extended

        def __len__(self):
            return tokenizer_len

    # Keep the original class name recognizable in reprs and logs.
    CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}"

    # Re-class the existing instance instead of copying it, so all other
    # tokenizer state is preserved as-is.
    tokenizer.__class__ = CachedTokenizer
    return tokenizer


54
def get_tokenizer(
    tokenizer_name: str,
    *args,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    tokenizer_revision: Optional[str] = None,
    **kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Gets a tokenizer for the given model name via Huggingface.

    Args:
        tokenizer_name: Name or path handed to
            `AutoTokenizer.from_pretrained`.
        *args: Extra positional arguments forwarded to `from_pretrained`.
        tokenizer_mode: When set to "slow", `use_fast=False` is forced.
            Any other value leaves the choice to `from_pretrained`.
        trust_remote_code: Forwarded to `from_pretrained`.
        tokenizer_revision: Revision of the tokenizer to load. Forwarded to
            `from_pretrained` as `revision`; an explicit `revision` entry in
            `kwargs` takes precedence.
        **kwargs: Extra keyword arguments forwarded to `from_pretrained`.

    Returns:
        The loaded tokenizer, patched by `get_cached_tokenizer`.

    Raises:
        ValueError: If `tokenizer_mode` is "slow" but `use_fast=True` was
            requested, or if loading fails with an unrelated `ValueError`.
        RuntimeError: If loading fails because the tokenizer needs remote
            code and `trust_remote_code` is False.
        AttributeError: If loading fails with an `AttributeError` that does
            not mention `BaichuanTokenizer`.
    """
    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
            raise ValueError(
                "Cannot use the fast tokenizer in slow tokenizer mode.")
        kwargs["use_fast"] = False

    # `from_pretrained` selects the branch/tag/commit through `revision`.
    # Passing the value under any other keyword does not pin the revision,
    # so translate it here. `setdefault` keeps an explicitly supplied
    # `revision` keyword authoritative.
    if tokenizer_revision is not None:
        kwargs.setdefault("revision", tokenizer_revision)

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name,
            *args,
            trust_remote_code=trust_remote_code,
            **kwargs)
    except ValueError as e:
        # If the error pertains to the tokenizer class not existing or not
        # currently being imported, suggest using the --trust-remote-code flag.
        if (not trust_remote_code and
            ("does not exist or is not currently imported." in str(e)
             or "requires you to execute the tokenizer file" in str(e))):
            err_msg = (
                "Failed to load the tokenizer. If the tokenizer is a custom "
                "tokenizer not yet available in the HuggingFace transformers "
                "library, consider setting `trust_remote_code=True` in LLM "
                "or using the `--trust-remote-code` flag in the CLI.")
            raise RuntimeError(err_msg) from e
        raise
    except AttributeError as e:
        if "BaichuanTokenizer" not in str(e):
            raise
        # This is for the error "'BaichuanTokenizer' object has no
        # attribute 'sp_model'": fall back to the BaichuanTokenizer class
        # brought in by the `vllm.transformers_utils.tokenizers` import.
        tokenizer = BaichuanTokenizer.from_pretrained(
            tokenizer_name,
            *args,
            trust_remote_code=trust_remote_code,
            **kwargs)

    if not isinstance(tokenizer, PreTrainedTokenizerFast):
        logger.warning(
            "Using a slow tokenizer. This might cause a significant "
            "slowdown. Consider using a fast tokenizer instead.")
    return get_cached_tokenizer(tokenizer)
108
109


110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
def get_lora_tokenizer(lora_request: LoRARequest, *args,
                       **kwargs) -> Optional[PreTrainedTokenizer]:
    """Load the tokenizer stored alongside a LoRA adapter, if there is one.

    Args:
        lora_request: The LoRA request whose `lora_local_path` is searched
            for a tokenizer. May be None.
        *args: Extra positional arguments forwarded to `get_tokenizer`.
        **kwargs: Extra keyword arguments forwarded to `get_tokenizer`.

    Returns:
        The tokenizer loaded from the LoRA path, or None when no request
        was given or when loading raised `OSError`. A None result means the
        caller should use the base model tokenizer.
    """
    if lora_request is None:
        return None

    lora_path = lora_request.lora_local_path
    try:
        return get_tokenizer(lora_path, *args, **kwargs)
    except OSError as e:
        # No tokenizer was found in the LoRA folder,
        # use base model tokenizer
        logger.warning(
            f"No tokenizer found in {lora_path}, "
            "using base model tokenizer instead. "
            f"(Exception: {str(e)})")
        return None


# Variant of `get_lora_tokenizer` produced by `vllm.utils.make_async`.
# NOTE(review): presumably this lets async callers await the blocking
# tokenizer load -- confirm against the `make_async` implementation.
get_lora_tokenizer_async = make_async(get_lora_tokenizer)


131
def _convert_tokens_to_string_with_added_encoders(
132
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
133
    output_tokens: List[str],
134
    skip_special_tokens: bool,
135
    spaces_between_special_tokens: bool,
136
) -> str:
137
138
    # Adapted from
    # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
139
140
141
142
143
    # NOTE(woosuk): The following code is slow because it runs a for loop over
    # the output_tokens. In Python, running a for loop over a list can be slow
    # even when the loop body is very simple.
    sub_texts = []
    current_sub_text = []
144
    all_special_tokens = set(tokenizer.all_special_tokens)
145
    for token in output_tokens:
146
        if skip_special_tokens and token in all_special_tokens:
147
            continue
148
        if token in tokenizer.get_added_vocab():
149
150
151
152
153
154
155
156
157
158
            if current_sub_text:
                sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
                sub_texts.append(sub_text)
                current_sub_text = []
            sub_texts.append(token)
        else:
            current_sub_text.append(token)
    if current_sub_text:
        sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
        sub_texts.append(sub_text)
159
160
161
162
    if spaces_between_special_tokens:
        return " ".join(sub_texts)
    else:
        return "".join(sub_texts)
163
164


165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
# Number of trailing prompt tokens kept in front of the read position when
# incremental detokenization starts. 5 is an arbitrary value that should
# work for all tokenizers (bigger = more conservative).
INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5


def convert_prompt_ids_to_tokens(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    prompt_ids: List[int],
    skip_special_tokens: bool = False,
) -> Tuple[List[str], int, int]:
    """Convert the tail of a prompt to tokens for incremental detokenization.

    Only the last few prompt ids are converted, since the rest of the prompt
    is not needed for incremental detokenization.

    Args:
        tokenizer: The tokenizer to use.
        prompt_ids: The token ids of the prompt.
        skip_special_tokens: Forwarded to `convert_ids_to_tokens`.

    Returns:
        A tuple `(tokens, prefix_offset, read_offset)`. `tokens` holds the
        converted tail of the prompt, `read_offset` is its length, and
        `prefix_offset` lies INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET
        tokens before `read_offset` (clamped at 0).
    """
    # Convert two ids more than the offset, in case some of them are
    # special tokens that get skipped.
    window_size = INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET + 2
    window_start = max(len(prompt_ids) - window_size, 0)
    tail_tokens = tokenizer.convert_ids_to_tokens(
        prompt_ids[window_start:], skip_special_tokens=skip_special_tokens)

    # Both offsets index into `tail_tokens`, not into the full prompt.
    num_tokens = len(tail_tokens)
    return (
        tail_tokens,
        max(num_tokens - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0),
        num_tokens,
    )


193
194
195
196
197
198
199
# Based on
# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
# under Apache 2.0 license
def detokenize_incrementally(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    all_input_ids: List[int],
    prev_tokens: Optional[List[str]],
    prefix_offset: int,
    read_offset: int,
    skip_special_tokens: bool = False,
    spaces_between_special_tokens: bool = True,
) -> Tuple[List[str], str, int, int]:
    """Detokenizes the input ids incrementally and returns the new tokens
    and the new text.

    If `prev_tokens` is None, this function will convert the input ids to
    tokens and return the tokens and the new text. Otherwise, it will return the
    new tokens and the new text.

    This function will also return the new prefix offset and the new read
    offset to be used in the next iteration.

    The offsets are necessary to defeat cleanup algorithms in the decode which
    decide to add a space or not depending on the surrounding ids.

    Args:
        tokenizer: The tokenizer to use.
        all_input_ids: The input ids. The last id is the new token id.
        prev_tokens: The previous tokens. If None, this function will convert
            the input ids to tokens and return the tokens and the new text.
        prefix_offset: The prefix offset. Ignored and recomputed when
            `prev_tokens` is None.
        read_offset: The read offset. Ignored and recomputed when
            `prev_tokens` is None.
        skip_special_tokens: Whether to skip special tokens.
        spaces_between_special_tokens: Whether to add spaces between special
            tokens.

    Returns:
        A tuple `(new_tokens, new_text, prefix_offset, read_offset)`.
        `new_tokens` is the token(s) for the new id, or every token
        converted so far on the first iteration. `new_text` is the text
        produced by this step, or "" when nothing can be emitted yet; in
        that case the offsets are returned unchanged, otherwise they are
        advanced to `(read_offset, number of tokens so far)`.
    """
    new_token_id = all_input_ids[-1]
    # This is the first iteration for this sequence
    is_first_iter = prev_tokens is None
    if is_first_iter:
        # Everything before the new token is treated as the prompt.
        (prev_tokens, prefix_offset,
         read_offset) = convert_prompt_ids_to_tokens(
             tokenizer,
             all_input_ids[:-1],
             skip_special_tokens=skip_special_tokens)

    # If the new token id is out of bounds, return an empty string.
    if new_token_id >= len(tokenizer):
        new_tokens = [""]
    else:
        # Put new_token_id in a list so skip_special_tokens is respected
        new_tokens = tokenizer.convert_ids_to_tokens(
            [new_token_id], skip_special_tokens=skip_special_tokens)
    output_tokens = prev_tokens + new_tokens

    # If this is the first iteration, return all tokens.
    if is_first_iter:
        new_tokens = output_tokens

    # The prefix text is necessary only to defeat cleanup algorithms in
    # the decode which decide to add a space or not depending on the
    # surrounding ids.
    if tokenizer.is_fast or not tokenizer.get_added_vocab():
        prefix_text = tokenizer.convert_tokens_to_string(
            output_tokens[prefix_offset:read_offset])
        new_text = tokenizer.convert_tokens_to_string(
            output_tokens[prefix_offset:])
    else:
        # Slow tokenizers with added vocabulary go through the helper that
        # emits added tokens verbatim.
        prefix_text = _convert_tokens_to_string_with_added_encoders(
            tokenizer,
            output_tokens[prefix_offset:read_offset],
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens,
        )
        new_text = _convert_tokens_to_string_with_added_encoders(
            tokenizer,
            output_tokens[prefix_offset:],
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens,
        )

    if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
        # utf-8 char at the end means it's a potential unfinished byte sequence
        # from byte fallback tokenization.
        # If it's in the middle, it's probably a real invalid id generated
        # by the model
        new_text = new_text[len(prefix_text):]
        return new_tokens, new_text, read_offset, len(output_tokens)
    else:
        # Nothing to emit yet: keep the offsets so the pending tokens are
        # decoded again together with the next one.
        return new_tokens, "", prefix_offset, read_offset