# tokenizer.py: HuggingFace tokenizer loading and incremental detokenization helpers.
from typing import List, Optional, Tuple, Union

from transformers import (AutoTokenizer, PreTrainedTokenizer,
                          PreTrainedTokenizerFast)

from vllm.logger import init_logger
# Presumably provides BaichuanTokenizer, which get_tokenizer() uses.
from vllm.transformers_utils.tokenizers import *

logger = init_logger(__name__)

# A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file.
# get_tokenizer() suggests it (in its log and error messages) as a
# replacement for original LLaMA V1 tokenizers, whose fast variant may take
# a long time to initialize.
_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"


def get_tokenizer(
    tokenizer_name: str,
    *args,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    tokenizer_revision: Optional[str] = None,
    **kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Gets a tokenizer for the given model name via Huggingface.

    Args:
        tokenizer_name: Name or local path of the tokenizer, passed to
            ``AutoTokenizer.from_pretrained``.
        *args: Extra positional arguments forwarded to ``from_pretrained``.
        tokenizer_mode: ``"slow"`` forces ``use_fast=False``; any other
            value leaves ``use_fast`` as given in ``kwargs``.
        trust_remote_code: Forwarded to ``from_pretrained``.
        tokenizer_revision: Revision of the tokenizer to load. When not
            ``None`` it is forwarded as HuggingFace's ``revision`` keyword
            and takes precedence over a ``revision`` given in ``kwargs``.
        **kwargs: Extra keyword arguments forwarded to ``from_pretrained``.

    Returns:
        The loaded tokenizer. A warning is logged when it is not a
        ``PreTrainedTokenizerFast``.

    Raises:
        ValueError: If ``tokenizer_mode == "slow"`` while ``use_fast=True``
            was requested, or re-raised from the HuggingFace loader.
        RuntimeError: If loading fails with a ``TypeError``, or with a
            ``ValueError`` that indicates ``trust_remote_code`` is needed;
            the original exception is chained as the cause.
        AttributeError: Re-raised unless it mentions ``BaichuanTokenizer``.
    """
    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
            raise ValueError(
                "Cannot use the fast tokenizer in slow tokenizer mode.")
        kwargs["use_fast"] = False

    # from_pretrained() selects the revision through its `revision` keyword;
    # `tokenizer_revision` is not one of its parameters, so passing it under
    # that name would leave the requested revision unused.
    if tokenizer_revision is not None:
        kwargs["revision"] = tokenizer_revision

    if ("llama" in tokenizer_name.lower() and kwargs.get("use_fast", True)
            and tokenizer_name != _FAST_LLAMA_TOKENIZER):
        logger.info(
            "For some LLaMA V1 models, initializing the fast tokenizer may "
            "take a long time. To reduce the initialization time, consider "
            f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
            "tokenizer.")
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name,
            *args,
            trust_remote_code=trust_remote_code,
            **kwargs)
    except TypeError as e:
        # The LLaMA tokenizer causes a protobuf error in some environments.
        err_msg = (
            "Failed to load the tokenizer. If you are using a LLaMA V1 model "
            f"consider using '{_FAST_LLAMA_TOKENIZER}' instead of the "
            "original tokenizer.")
        raise RuntimeError(err_msg) from e
    except ValueError as e:
        # If the error pertains to the tokenizer class not existing or not
        # currently being imported, suggest using the --trust-remote-code flag.
        if (not trust_remote_code and
            ("does not exist or is not currently imported." in str(e)
             or "requires you to execute the tokenizer file" in str(e))):
            err_msg = (
                "Failed to load the tokenizer. If the tokenizer is a custom "
                "tokenizer not yet available in the HuggingFace transformers "
                "library, consider setting `trust_remote_code=True` in LLM "
                "or using the `--trust-remote-code` flag in the CLI.")
            raise RuntimeError(err_msg) from e
        raise
    except AttributeError as e:
        if "BaichuanTokenizer" not in str(e):
            raise
        # This is for the error "'BaichuanTokenizer' object has no
        # attribute 'sp_model'". BaichuanTokenizer is expected to come from
        # the `vllm.transformers_utils.tokenizers` star import.
        tokenizer = BaichuanTokenizer.from_pretrained(
            tokenizer_name,
            *args,
            trust_remote_code=trust_remote_code,
            **kwargs)

    if not isinstance(tokenizer, PreTrainedTokenizerFast):
        logger.warning(
            "Using a slow tokenizer. This might cause a significant "
            "slowdown. Consider using a fast tokenizer instead.")
    return tokenizer


def _convert_tokens_to_string_with_added_encoders(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    output_tokens: List[str],
    skip_special_tokens: bool,
    spaces_between_special_tokens: bool,
) -> str:
    """Convert tokens to text, keeping added-vocabulary tokens verbatim.

    Runs of ordinary tokens are decoded together with
    ``tokenizer.convert_tokens_to_string``. Every token found in
    ``tokenizer.get_added_vocab()`` is emitted unchanged as its own piece.

    Args:
        tokenizer: Tokenizer used for decoding.
        output_tokens: Token strings to convert.
        skip_special_tokens: If true, tokens listed in
            ``tokenizer.all_special_tokens`` are dropped.
        spaces_between_special_tokens: If true, the decoded pieces are
            joined with a single space; otherwise they are concatenated.

    Returns:
        The decoded text.
    """
    # Adapted from
    # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
    # NOTE(woosuk): The following code is slow because it runs a for loop over
    # the output_tokens. In Python, running a for loop over a list can be slow
    # even when the loop body is very simple.

    # Look these up once rather than per token. get_added_vocab() is a method
    # call that may do non-trivial work (e.g. build a dict) on every call.
    added_vocab = tokenizer.get_added_vocab()
    special_tokens = (set(tokenizer.all_special_tokens)
                      if skip_special_tokens else set())

    sub_texts: List[str] = []
    current_sub_text: List[str] = []
    for token in output_tokens:
        if token in special_tokens:
            continue
        if token in added_vocab:
            # Flush the pending run of ordinary tokens, then keep the added
            # token as its own verbatim piece.
            if current_sub_text:
                sub_texts.append(
                    tokenizer.convert_tokens_to_string(current_sub_text))
                current_sub_text = []
            sub_texts.append(token)
        else:
            current_sub_text.append(token)
    if current_sub_text:
        sub_texts.append(tokenizer.convert_tokens_to_string(current_sub_text))

    separator = " " if spaces_between_special_tokens else ""
    return separator.join(sub_texts)


# Based on
# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
# under Apache 2.0 license
def detokenize_incrementally(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    all_input_ids: List[int],
    prev_tokens: Optional[List[str]],
    prefix_offset: int = 0,
    read_offset: int = 0,
    skip_special_tokens: bool = False,
    spaces_between_special_tokens: bool = True,
) -> Tuple[List[str], str, int, int]:
    """Turn the most recent token id of a sequence into newly readable text.

    Only the last id in ``all_input_ids`` is treated as new. The text of the
    window ``output_tokens[prefix_offset:]`` is compared with the text of
    ``output_tokens[prefix_offset:read_offset]``. Whatever the former adds
    beyond the latter is returned as the text delta.

    Args:
        tokenizer: Tokenizer used for id->token and token->text conversion.
        all_input_ids: All token ids of the sequence so far; the last one is
            the token being decoded.
        prev_tokens: Token strings from earlier calls, or ``None`` on the
            first call for this sequence (all ids are then converted).
        prefix_offset: Start of the context window in the token list;
            recomputed when ``prev_tokens`` is ``None``.
        read_offset: End of the already-emitted part of the window;
            recomputed when ``prev_tokens`` is ``None``.
        skip_special_tokens: Drop special tokens during conversion.
        spaces_between_special_tokens: Join added-vocab pieces with spaces
            (only used for slow tokenizers that have an added vocabulary).

    Returns:
        ``(new_tokens, new_text, prefix_offset, read_offset)``: the tokens
        converted by this call, the text delta (``""`` when nothing is ready
        yet), and the updated offsets.
    """
    latest_id = all_input_ids[-1]

    if prev_tokens is not None:
        # Wrapping the single id in a list lets convert_ids_to_tokens apply
        # skip_special_tokens to it.
        new_tokens = tokenizer.convert_ids_to_tokens(
            [latest_id], skip_special_tokens=skip_special_tokens)
        output_tokens = prev_tokens + new_tokens
    else:
        # First call for this sequence: convert the whole id list.
        new_tokens = tokenizer.convert_ids_to_tokens(
            all_input_ids, skip_special_tokens=skip_special_tokens)
        output_tokens = new_tokens
        token_count = len(output_tokens)
        # Look back 5 tokens (an arbitrary margin that should suit all
        # tokenizers; bigger is more conservative) plus 1 for the new token.
        prefix_offset = max(token_count - 6, 0)
        # With skip_special_tokens a special latest id is expected to be
        # dropped by convert_ids_to_tokens, so there is no trailing token to
        # hold back from the already-read region.
        latest_was_skipped = (skip_special_tokens
                              and latest_id in tokenizer.all_special_ids)
        if latest_was_skipped:
            read_offset = token_count
        else:
            read_offset = max(token_count - 1, 0)

    context_tokens = output_tokens[prefix_offset:read_offset]
    window_tokens = output_tokens[prefix_offset:]

    # The context text exists only so that decode-time cleanup (which may
    # add a space or not depending on the surrounding ids) sees the same
    # neighbours in both strings; it is cut off the front of the result.
    if tokenizer.is_fast or not tokenizer.get_added_vocab():
        prefix_text = tokenizer.convert_tokens_to_string(context_tokens)
        new_text = tokenizer.convert_tokens_to_string(window_tokens)
    else:
        prefix_text = _convert_tokens_to_string_with_added_encoders(
            tokenizer,
            context_tokens,
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens,
        )
        new_text = _convert_tokens_to_string_with_added_encoders(
            tokenizer,
            window_tokens,
            skip_special_tokens=skip_special_tokens,
            spaces_between_special_tokens=spaces_between_special_tokens,
        )

    # Hold back when nothing new was produced, or when the text ends in the
    # replacement character: that may be an unfinished byte sequence from
    # byte-fallback tokenization. (One in the middle is probably a genuinely
    # invalid id generated by the model.)
    if len(new_text) <= len(prefix_text) or new_text.endswith("�"):
        return new_tokens, "", prefix_offset, read_offset
    delta = new_text[len(prefix_text):]
    return new_tokens, delta, read_offset, len(output_tokens)