"tests/vscode:/vscode.git/clone" did not exist on "a4ec0c559521c055519eeabddd8279c83eb4e936"
tokenizer.py 5.71 KB
Newer Older
1
import os
2
from pathlib import Path
3
from typing import Optional, Union
4

5
import huggingface_hub
6
from transformers import (AutoTokenizer, PreTrainedTokenizer,
7
8
                          PreTrainedTokenizerFast)

9
from vllm.envs import VLLM_USE_MODELSCOPE
# Woosuk Kwon's avatar
# Woosuk Kwon committed
10
from vllm.logger import init_logger
11
from vllm.lora.request import LoRARequest
12
from vllm.transformers_utils.tokenizers import BaichuanTokenizer
13
from vllm.utils import make_async
14
15
16

# Module-level logger created via vLLM's `init_logger` helper.
logger = init_logger(__name__)

17

18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
def get_cached_tokenizer(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Get tokenizer with cached properties.

    This will patch the tokenizer object in place.

    By default, transformers will recompute multiple tokenizer properties
    each time they are called, leading to a significant slowdown. This
    function caches these properties for faster access.

    Args:
        tokenizer: The tokenizer to patch. Its ``__class__`` is replaced with
            a dynamically created subclass whose ``all_special_ids``,
            ``all_special_tokens``, ``all_special_tokens_extended`` and
            ``__len__`` return values captured when this function runs.

    Returns:
        The same ``tokenizer`` object, now an instance of the cached subclass.

    Note:
        The values are snapshotted once. Later changes to the tokenizer's
        special tokens or vocabulary are not reflected by these properties.
        ``all_special_ids`` and ``all_special_tokens`` are returned as sets.
    """
    original_cls = tokenizer.__class__

    tokenizer_all_special_ids = set(tokenizer.all_special_ids)
    tokenizer_all_special_tokens_extended = (
        tokenizer.all_special_tokens_extended)
    tokenizer_all_special_tokens = set(tokenizer.all_special_tokens)
    tokenizer_len = len(tokenizer)

    class CachedTokenizer(original_cls):  # type: ignore

        @property
        def all_special_ids(self):
            return tokenizer_all_special_ids

        @property
        def all_special_tokens(self):
            return tokenizer_all_special_tokens

        @property
        def all_special_tokens_extended(self):
            return tokenizer_all_special_tokens_extended

        def __len__(self):
            return tokenizer_len

        def __reduce__(self):
            # CachedTokenizer is a function-local class, so pickle cannot
            # look it up by qualified name. Serialize an instance of the
            # original (picklable) class carrying the same state, and
            # re-apply the caching wrapper when it is loaded.
            plain = original_cls.__new__(original_cls)
            plain.__dict__.update(self.__dict__)
            return get_cached_tokenizer, (plain, )

    CachedTokenizer.__name__ = f"Cached{original_cls.__name__}"

    tokenizer.__class__ = CachedTokenizer
    return tokenizer


58
def get_tokenizer(
    tokenizer_name: Union[str, Path],
    *args,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    revision: Optional[str] = None,
    download_dir: Optional[str] = None,
    **kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Gets a tokenizer for the given model name via HuggingFace or ModelScope.

    Args:
        tokenizer_name: Hub model id or local path. A path to an existing
            ``.gguf`` file is split into its parent directory plus a
            ``gguf_file`` keyword argument for ``from_pretrained``.
        *args: Extra positional arguments forwarded to ``from_pretrained``.
        tokenizer_mode: ``"slow"`` forces ``use_fast=False``; any other value
            leaves the fast/slow choice to ``AutoTokenizer``.
        trust_remote_code: Forwarded to ``from_pretrained``.
        revision: Forwarded to ModelScope's ``snapshot_download`` (when
            ModelScope is enabled) and to ``from_pretrained``.
        download_dir: Cache directory for ModelScope downloads.
        **kwargs: Extra keyword arguments forwarded to ``from_pretrained``.
            ``truncation_side`` defaults to ``"left"`` if not given.

    Returns:
        The loaded tokenizer, patched by ``get_cached_tokenizer``.

    Raises:
        ValueError: If ``tokenizer_mode == "slow"`` while ``use_fast=True`` is
            requested, or if ``from_pretrained`` raises a ``ValueError`` that
            is not a missing-custom-code error.
        RuntimeError: If loading fails because the tokenizer needs custom
            code and ``trust_remote_code`` is False.
    """
    if VLLM_USE_MODELSCOPE:
        # download model from ModelScope hub,
        # lazy import so that modelscope is not required for normal use.
        # pylint: disable=C.
        from modelscope.hub.snapshot_download import snapshot_download

        # Only set the tokenizer here, model will be downloaded on the workers.
        if not os.path.exists(tokenizer_name):
            tokenizer_path = snapshot_download(
                model_id=tokenizer_name,
                cache_dir=download_dir,
                revision=revision,
                local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                # Ignore weights - we only need the tokenizer.
                ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])
            tokenizer_name = tokenizer_path

    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
            raise ValueError(
                "Cannot use the fast tokenizer in slow tokenizer mode.")
        kwargs["use_fast"] = False

    # Default to left-side truncation unless the caller chose otherwise.
    if "truncation_side" not in kwargs:
        kwargs["truncation_side"] = "left"

    # Separate model folder from file path for GGUF models
    candidate_path = Path(tokenizer_name)
    is_gguf = candidate_path.is_file() and candidate_path.suffix == ".gguf"
    if is_gguf:
        kwargs["gguf_file"] = candidate_path.name
        tokenizer_name = candidate_path.parent

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name,
            *args,
            trust_remote_code=trust_remote_code,
            revision=revision,
            **kwargs)
    except ValueError as e:
        # If the error pertains to the tokenizer class not existing or not
        # currently being imported, suggest using the --trust-remote-code flag.
        err_str = str(e)
        if (not trust_remote_code and
            ("does not exist or is not currently imported." in err_str
             or "requires you to execute the tokenizer file" in err_str)):
            err_msg = (
                "Failed to load the tokenizer. If the tokenizer is a custom "
                "tokenizer not yet available in the HuggingFace transformers "
                "library, consider setting `trust_remote_code=True` in LLM "
                "or using the `--trust-remote-code` flag in the CLI.")
            raise RuntimeError(err_msg) from e
        # Bare `raise` re-raises with the original traceback untouched.
        raise
    except AttributeError as e:
        if "BaichuanTokenizer" not in str(e):
            raise
        # This is for the error "'BaichuanTokenizer' object has no
        # attribute 'sp_model'".
        tokenizer = BaichuanTokenizer.from_pretrained(
            tokenizer_name,
            *args,
            trust_remote_code=trust_remote_code,
            revision=revision,
            **kwargs)

    if not isinstance(tokenizer, PreTrainedTokenizerFast):
        logger.warning(
            "Using a slow tokenizer. This might cause a significant "
            "slowdown. Consider using a fast tokenizer instead.")
    return get_cached_tokenizer(tokenizer)
141
142


143
144
145
146
147
def get_lora_tokenizer(
    lora_request: Optional[LoRARequest], *args, **kwargs
) -> Optional[Union[PreTrainedTokenizer, PreTrainedTokenizerFast]]:
    """Load the tokenizer stored alongside a LoRA adapter, if there is one.

    Args:
        lora_request: The LoRA request whose ``lora_path`` is searched for a
            tokenizer. ``None`` means no LoRA is in use.
        *args: Forwarded to ``get_tokenizer``.
        **kwargs: Forwarded to ``get_tokenizer``.

    Returns:
        The tokenizer loaded from ``lora_request.lora_path``, or ``None`` when
        ``lora_request`` is ``None`` or loading raised ``OSError``. In the
        latter case a warning is logged and the base model tokenizer is meant
        to be used instead.
    """
    if lora_request is None:
        return None
    try:
        tokenizer = get_tokenizer(lora_request.lora_path, *args, **kwargs)
    except OSError as e:
        # No tokenizer was found in the LoRA folder,
        # use base model tokenizer
        logger.warning(
            "No tokenizer found in %s, using base model tokenizer instead. "
            "(Exception: %s)", lora_request.lora_path, e)
        tokenizer = None
    return tokenizer


# Async variant of get_lora_tokenizer, built with vllm.utils.make_async.
get_lora_tokenizer_async = make_async(get_lora_tokenizer)