from typing import List, Optional, Tuple, Union

from transformers import (AutoTokenizer, PreTrainedTokenizer,
                          PreTrainedTokenizerFast)

from vllm.logger import init_logger

logger = init_logger(__name__)

# A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file.
_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"


def get_tokenizer(
    tokenizer_name: str,
    *args,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    **kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Gets a tokenizer for the given model name via Huggingface."""
    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
            raise ValueError(
                "Cannot use the fast tokenizer in slow tokenizer mode.")
        kwargs["use_fast"] = False

    if ("llama" in tokenizer_name.lower() and kwargs.get("use_fast", True)
            and tokenizer_name != _FAST_LLAMA_TOKENIZER):
        logger.info(
            "For some LLaMA-based models, initializing the fast tokenizer may "
            "take a long time. To eliminate the initialization time, consider "
            f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
            "tokenizer.")
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name,
            *args,
            trust_remote_code=trust_remote_code,
            **kwargs)
    except TypeError as e:
        # The LLaMA tokenizer causes a protobuf error in some environments.
        err_msg = (
            "Failed to load the tokenizer. If you are using a LLaMA-based "
            f"model, use '{_FAST_LLAMA_TOKENIZER}' instead of the original "
            "tokenizer.")
        raise RuntimeError(err_msg) from e
    except ValueError as e:
        # If the error pertains to the tokenizer class not existing or not
        # currently being imported, suggest using the --trust-remote-code flag.
        if (not trust_remote_code and
            ("does not exist or is not currently imported." in str(e)
             or "requires you to execute the tokenizer file" in str(e))):
            err_msg = (
                "Failed to load the tokenizer. If the tokenizer is a custom "
                "tokenizer not yet available in the HuggingFace transformers "
                "library, consider setting `trust_remote_code=True` in LLM "
                "or using the `--trust-remote-code` flag in the CLI.")
            raise RuntimeError(err_msg) from e
        else:
            raise e

    if not isinstance(tokenizer, PreTrainedTokenizerFast):
        logger.warning(
            "Using a slow tokenizer. This might cause a significant "
            "slowdown. Consider using a fast tokenizer instead.")
    return tokenizer
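
# Illustrative usage (a sketch, not part of the original module): the model
# name below is a placeholder assumption; any Hugging Face repo that ships a
# tokenizer works the same way.
#
#     tokenizer = get_tokenizer("facebook/opt-125m", tokenizer_mode="auto")
#     token_ids = tokenizer.encode("Hello, world!")
#     text = tokenizer.decode(token_ids)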


def _convert_tokens_to_string_with_added_encoders(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    output_tokens: List[str],
    skip_special_tokens: bool,
) -> str:
    # Adapted from
    # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
    # NOTE(woosuk): The following code is slow because it runs a for loop over
    # the output_tokens. In Python, running a for loop over a list can be slow
    # even when the loop body is very simple.
    sub_texts = []
    current_sub_text = []
    for token in output_tokens:
        if skip_special_tokens and token in tokenizer.all_special_tokens:
            continue
        if token in tokenizer.added_tokens_encoder:
            if current_sub_text:
                sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
                sub_texts.append(sub_text)
                current_sub_text = []
            sub_texts.append(token)
        else:
            current_sub_text.append(token)
    if current_sub_text:
        sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
        sub_texts.append(sub_text)
    return " ".join(sub_texts)
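
# Worked example (an assumption for illustration, not from the original
# source): suppose "<custom>" is a hypothetical entry in
# tokenizer.added_tokens_encoder. Then output_tokens like
# ["▁Hello", "<custom>", "▁world"] is split into the runs ["▁Hello"] and
# ["▁world"], each run is decoded via convert_tokens_to_string, the added
# token is kept verbatim, and the pieces are joined with single spaces,
# yielding roughly "Hello <custom> world".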


# Based on
# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
# under Apache 2.0 license
def detokenize_incrementally(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    all_input_ids: List[int],
    prev_tokens: Optional[List[str]],
    prefix_offset: int = 0,
    read_offset: int = 0,
    skip_special_tokens: bool = False,
) -> Tuple[List[str], str, int, int]:
    new_token_id = all_input_ids[-1]
    # This is the first iteration for this sequence
    if prev_tokens is None:
        new_tokens = tokenizer.convert_ids_to_tokens(
            all_input_ids, skip_special_tokens=skip_special_tokens)
        output_tokens = new_tokens
        # 5 is an arbitrary value that should work for all
        # tokenizers (bigger = more conservative).
        # Subtract 1 extra to account for the generated token.
        prefix_offset = max(len(output_tokens) - 6, 0)
        read_offset = max(len(output_tokens) - 1, 0)
    else:
        new_token = tokenizer.convert_ids_to_tokens(
            new_token_id, skip_special_tokens=skip_special_tokens)
        new_tokens = [new_token]
        output_tokens = prev_tokens + new_tokens

    # The prefix text is necessary only to defeat cleanup algorithms in
    # the decode which decide to add a space or not depending on the
    # surrounding ids.
    if not getattr(tokenizer, "added_tokens_encoder", {}):
        prefix_text = tokenizer.convert_tokens_to_string(
            output_tokens[prefix_offset:read_offset])
        new_text = tokenizer.convert_tokens_to_string(
            output_tokens[prefix_offset:])
    else:
        prefix_text = _convert_tokens_to_string_with_added_encoders(
            tokenizer,
            output_tokens[prefix_offset:read_offset],
            skip_special_tokens=skip_special_tokens)
        new_text = _convert_tokens_to_string_with_added_encoders(
            tokenizer,
            output_tokens[prefix_offset:],
            skip_special_tokens=skip_special_tokens)

    if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
        # utf-8 char at the end means it's a potential unfinished byte sequence
        # from byte fallback tokenization.
        # If it's in the middle, it's probably a real invalid id generated
        # by the model
        new_text = new_text[len(prefix_text):]
        return new_tokens, new_text, read_offset, len(output_tokens)
    else:
        return new_tokens, "", prefix_offset, read_offset
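
# Illustrative decoding loop (a sketch of how a caller might drive this
# helper; the sampling function below is a hypothetical stand-in, not part of
# the original module). Each step appends the newest token id and carries the
# returned offsets forward so text is emitted incrementally:
#
#     prev_tokens = None
#     prefix_offset = read_offset = 0
#     generated_text = ""
#     for _ in range(max_new_tokens):
#         all_input_ids.append(sample_next_token())  # hypothetical sampler
#         new_tokens, new_text, prefix_offset, read_offset = (
#             detokenize_incrementally(tokenizer, all_input_ids, prev_tokens,
#                                      prefix_offset, read_offset))
#         prev_tokens = (new_tokens if prev_tokens is None
#                        else prev_tokens + new_tokens)
#         generated_text += new_text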