tokenizer.py 4.49 KB
Newer Older
1
from typing import Optional, Union
2

3
from transformers import (AutoTokenizer, PreTrainedTokenizer,
4
5
                          PreTrainedTokenizerFast)

Woosuk Kwon's avatar
Woosuk Kwon committed
6
from vllm.logger import init_logger
7
from vllm.lora.request import LoRARequest
8
from vllm.transformers_utils.tokenizers import *
9
from vllm.utils import make_async
10
11
12

# Module-level logger for this file, created via vLLM's `init_logger` helper.
logger = init_logger(__name__)

13

14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
def get_cached_tokenizer(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Get tokenizer with cached properties.

    This will patch the tokenizer object in place.

    By default, transformers will recompute multiple tokenizer properties
    each time they are called, leading to a significant slowdown. This
    function caches these properties for faster access.

    The cached values are snapshotted once, at the time this function is
    called. Tokens added to the tokenizer afterwards are NOT reflected in
    ``all_special_ids``, ``all_special_tokens``,
    ``all_special_tokens_extended`` or ``len(tokenizer)``.

    Args:
        tokenizer: The tokenizer to patch.

    Returns:
        The same tokenizer object. Its class has been replaced by a
        dynamically created subclass named ``Cached<OriginalClassName>``
        whose properties return the snapshotted values.
    """
    # Snapshot the expensive properties once. The special ids/tokens are
    # stored as sets (presumably so callers get O(1) membership checks);
    # the "extended" variant is kept as returned by transformers.
    tokenizer_all_special_ids = set(tokenizer.all_special_ids)
    tokenizer_all_special_tokens_extended = (
        tokenizer.all_special_tokens_extended)
    tokenizer_all_special_tokens = set(tokenizer.all_special_tokens)
    tokenizer_len = len(tokenizer)

    class CachedTokenizer(tokenizer.__class__):
        # Subclass of the tokenizer's own class, so every other method and
        # attribute behaves exactly as before; only these accessors change.

        @property
        def all_special_ids(self):
            return tokenizer_all_special_ids

        @property
        def all_special_tokens(self):
            return tokenizer_all_special_tokens

        @property
        def all_special_tokens_extended(self):
            return tokenizer_all_special_tokens_extended

        def __len__(self):
            return tokenizer_len

    # Give the patched class a recognisable name (e.g. for reprs/logging).
    CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}"

    # Swap the class of the existing instance instead of copying it, so all
    # existing references to `tokenizer` see the cached properties.
    tokenizer.__class__ = CachedTokenizer
    return tokenizer


54
def get_tokenizer(
    tokenizer_name: str,
    *args,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    tokenizer_revision: Optional[str] = None,
    **kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Gets a tokenizer for the given model name via Huggingface.

    Args:
        tokenizer_name: Model id or local path passed to ``from_pretrained``.
        *args: Extra positional arguments forwarded to ``from_pretrained``.
        tokenizer_mode: ``"slow"`` forces ``use_fast=False``; any other value
            leaves the fast/slow choice to ``AutoTokenizer``.
        trust_remote_code: Forwarded to ``from_pretrained``.
        tokenizer_revision: Revision of the tokenizer to load. It is
            forwarded to ``from_pretrained`` as its ``revision`` argument.
        **kwargs: Extra keyword arguments forwarded to ``from_pretrained``.

    Returns:
        The loaded tokenizer, patched in place by ``get_cached_tokenizer``.

    Raises:
        ValueError: If ``tokenizer_mode == "slow"`` while ``use_fast=True``
            was requested; if ``tokenizer_revision`` conflicts with an
            explicit ``revision`` keyword argument; or if loading fails with
            a ValueError unrelated to missing remote code.
        RuntimeError: If loading fails because the tokenizer class needs
            custom code and ``trust_remote_code`` is False.
    """
    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
            raise ValueError(
                "Cannot use the fast tokenizer in slow tokenizer mode.")
        kwargs["use_fast"] = False

    # Hugging Face's `from_pretrained` selects the revision through the
    # `revision` keyword. A `tokenizer_revision` keyword is not a recognised
    # download option, so passing it through would silently load the default
    # revision instead of the requested one.
    if tokenizer_revision is not None:
        if kwargs.get("revision", tokenizer_revision) != tokenizer_revision:
            raise ValueError(
                f"Conflicting revisions: tokenizer_revision="
                f"{tokenizer_revision!r} but revision="
                f"{kwargs['revision']!r}.")
        kwargs["revision"] = tokenizer_revision

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name,
            *args,
            trust_remote_code=trust_remote_code,
            **kwargs)
    except ValueError as e:
        # If the error pertains to the tokenizer class not existing or not
        # currently being imported, suggest using the --trust-remote-code flag.
        if (not trust_remote_code and
            ("does not exist or is not currently imported." in str(e)
             or "requires you to execute the tokenizer file" in str(e))):
            err_msg = (
                "Failed to load the tokenizer. If the tokenizer is a custom "
                "tokenizer not yet available in the HuggingFace transformers "
                "library, consider setting `trust_remote_code=True` in LLM "
                "or using the `--trust-remote-code` flag in the CLI.")
            raise RuntimeError(err_msg) from e
        raise
    except AttributeError as e:
        if "BaichuanTokenizer" not in str(e):
            raise
        # This is for the error "'BaichuanTokenizer' object has no
        # attribute 'sp_model'": fall back to vLLM's own Baichuan tokenizer.
        tokenizer = BaichuanTokenizer.from_pretrained(
            tokenizer_name,
            *args,
            trust_remote_code=trust_remote_code,
            **kwargs)

    if not isinstance(tokenizer, PreTrainedTokenizerFast):
        logger.warning(
            "Using a slow tokenizer. This might cause a significant "
            "slowdown. Consider using a fast tokenizer instead.")
    return get_cached_tokenizer(tokenizer)
108
109


110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
def get_lora_tokenizer(
    lora_request: Optional[LoRARequest], *args, **kwargs
) -> Optional[Union[PreTrainedTokenizer, PreTrainedTokenizerFast]]:
    """Load the tokenizer stored alongside a LoRA adapter, if there is one.

    Args:
        lora_request: The LoRA request whose ``lora_local_path`` is searched
            for a tokenizer, or None.
        *args: Positional arguments forwarded to ``get_tokenizer``.
        **kwargs: Keyword arguments forwarded to ``get_tokenizer``.

    Returns:
        The tokenizer loaded from ``lora_request.lora_local_path``. Returns
        None if ``lora_request`` is None or if loading raised ``OSError``;
        in the latter case, a warning says the base model tokenizer will be
        used instead.
    """
    if lora_request is None:
        return None
    try:
        return get_tokenizer(lora_request.lora_local_path, *args, **kwargs)
    except OSError as e:
        # No tokenizer was found in the LoRA folder,
        # use base model tokenizer
        logger.warning(
            "No tokenizer found in %s, using base model tokenizer instead. "
            "(Exception: %s)", lora_request.lora_local_path, e)
        return None


# Async counterpart of `get_lora_tokenizer`, produced by `vllm.utils.make_async`
# (presumably wraps the blocking call so it can be awaited — confirm in
# vllm.utils).
get_lora_tokenizer_async = make_async(get_lora_tokenizer)