"""Utilities for Huggingface Transformers."""

import functools
import json
import os
import warnings
from typing import AbstractSet, Collection, Dict, Literal, Optional, Type, Union

from huggingface_hub import snapshot_download
from transformers import (
    AutoConfig,
    AutoProcessor,
    AutoTokenizer,
    PretrainedConfig,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
)
from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig

from sglang.srt.utils import is_multimodal_model

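# Model types whose configs are re-loaded with vllm's custom config classes
# after AutoConfig detects the model type.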
_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
    ChatGLMConfig.model_type: ChatGLMConfig,
    DbrxConfig.model_type: DbrxConfig,
}


def download_from_hf(model_path: str):
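    """Return a local path for the model, downloading *.json/*.bin/*.model
    files from the HuggingFace Hub when the path does not exist locally."""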
    if os.path.exists(model_path):
        return model_path

    return snapshot_download(model_path, allow_patterns=["*.json", "*.bin", "*.model"])


def get_config_json(model_path: str):
    with open(os.path.join(model_path, "config.json")) as f:
        config = json.load(f)
    return config


def get_config(
    model: str,
    trust_remote_code: bool,
    revision: Optional[str] = None,
    model_override_args: Optional[dict] = None,
):
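    """Load the HuggingFace config for a model, dispatching to a custom config
    class from _CONFIG_REGISTRY when needed and applying any override args."""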
    config = AutoConfig.from_pretrained(
        model, trust_remote_code=trust_remote_code, revision=revision
    )
    if config.model_type in _CONFIG_REGISTRY:
        config_class = _CONFIG_REGISTRY[config.model_type]
        config = config_class.from_pretrained(model, revision=revision)
    if model_override_args:
        config.update(model_override_args)
    return config


# Models don't use the same configuration key for determining the maximum
# context length.  Store them here so we can sanely check them.
# NOTE: The ordering here is important. Some models have two of these and we
# have a preference for which value gets used.
CONTEXT_LENGTH_KEYS = [
    "max_sequence_length",
    "seq_length",
    "max_position_embeddings",
    "max_seq_len",
    "model_max_length",
]


def get_context_length(config):
    """Get the context length of a model from a huggingface model config."""
    rope_scaling = getattr(config, "rope_scaling", None)
    if rope_scaling:
        rope_scaling_factor = rope_scaling["factor"]
        # For these variants, the config's stored context length already
        # reflects the scaled value, so do not multiply by the factor again.
        if "original_max_position_embeddings" in rope_scaling:
            rope_scaling_factor = 1
        if rope_scaling.get("rope_type", None) == "llama3":
            rope_scaling_factor = 1
    else:
        rope_scaling_factor = 1

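    # Example (hypothetical values): max_position_embeddings=4096 with
    # rope_scaling={"factor": 2.0} gives int(2.0 * 4096) = 8192; without
    # rope_scaling, 4096; with no known key at all, the 2048 fallback.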
    for key in CONTEXT_LENGTH_KEYS:
        val = getattr(config, key, None)
        if val is not None:
            return int(rope_scaling_factor * val)
    return 2048


# A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file.
_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"


def get_tokenizer(
    tokenizer_name: str,
    *args,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    tokenizer_revision: Optional[str] = None,
    **kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Get a tokenizer for the given model name via Huggingface."""
    if tokenizer_name.endswith(".json"):
        return TiktokenTokenizer(tokenizer_name)

    if tokenizer_name.endswith(".model"):
        return SentencePieceTokenizer(tokenizer_name)

    """Gets a tokenizer for the given model name via Huggingface."""
    if is_multimodal_model(tokenizer_name):
        processor = get_processor(
            tokenizer_name,
            *args,
            trust_remote_code=trust_remote_code,
            tokenizer_revision=tokenizer_revision,
            **kwargs,
        )
        tokenizer = processor.tokenizer
        return tokenizer

    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
            raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
        kwargs["use_fast"] = False

    # NOTE: For some LLaMA V1 models, initializing the fast tokenizer may take
    # a long time. To reduce the initialization time, consider using
    # _FAST_LLAMA_TOKENIZER instead of the original tokenizer; the warning for
    # this case is intentionally suppressed.
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name,
            *args,
            trust_remote_code=trust_remote_code,
            tokenizer_revision=tokenizer_revision,
            **kwargs,
        )
    except TypeError as e:
        # The LLaMA tokenizer causes a protobuf error in some environments.
        err_msg = (
            "Failed to load the tokenizer. If you are using a LLaMA V1 model "
            f"consider using '{_FAST_LLAMA_TOKENIZER}' instead of the "
            "original tokenizer."
        )
        raise RuntimeError(err_msg) from e
    except ValueError as e:
        # If the error pertains to the tokenizer class not existing or not
        # currently being imported, suggest using the --trust-remote-code flag.
        if not trust_remote_code and (
            "does not exist or is not currently imported." in str(e)
            or "requires you to execute the tokenizer file" in str(e)
        ):
            err_msg = (
                "Failed to load the tokenizer. If the tokenizer is a custom "
                "tokenizer not yet available in the HuggingFace transformers "
                "library, consider setting `trust_remote_code=True` in LLM "
                "or using the `--trust-remote-code` flag in the CLI."
            )
            raise RuntimeError(err_msg) from e
        else:
            raise e

    if not isinstance(tokenizer, PreTrainedTokenizerFast):
        warnings.warn(
            "Using a slow tokenizer. This might cause a significant "
            "slowdown. Consider using a fast tokenizer instead."
        )
    return tokenizer

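# Usage sketch (hypothetical model name):
#   tokenizer = get_tokenizer("meta-llama/Llama-2-7b-hf")
#   ids = tokenizer.encode("Hello world")
#   text = tokenizer.decode(ids)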

def get_processor(
    tokenizer_name: str,
    *args,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    tokenizer_revision: Optional[str] = None,
    **kwargs,
):
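    """Load a HuggingFace processor; get_tokenizer uses its .tokenizer
    attribute for multimodal models."""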
    processor = AutoProcessor.from_pretrained(
        tokenizer_name,
        *args,
        trust_remote_code=trust_remote_code,
        tokenizer_revision=tokenizer_revision,
        **kwargs,
    )
    return processor


class TiktokenTokenizer:
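    """A tiktoken-based tokenizer built from a pre-processed JSON file,
    exposing a minimal HuggingFace-style interface."""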
    def __init__(self, tokenizer_path):
        import tiktoken
        from jinja2 import Template

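        # GPT-style pre-tokenization regex (contractions, letters, digits,
        # punctuation, whitespace); the default when the JSON omits "pat_str".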
        PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""

        # Read JSON
        name = "tmp-json"
        with open(tokenizer_path, "rb") as fin:
            tok_dict = json.load(fin)

        mergeable_ranks = {
            bytes(item["bytes"]): item["token"] for item in tok_dict["regular_tokens"]
        }
        special_tokens = {
            bytes(item["bytes"]).decode(): item["token"]
            for item in tok_dict["special_tokens"]
        }
        assert tok_dict["word_split"] == "V1"

        kwargs = {
            "name": name,
            "pat_str": tok_dict.get("pat_str", PAT_STR_B),
            "mergeable_ranks": mergeable_ranks,
            "special_tokens": special_tokens,
        }
        if "default_allowed_special" in tok_dict:
            default_allowed_special = {
                bytes(bytes_list).decode()
                for bytes_list in tok_dict["default_allowed_special"]
            }
        else:
            default_allowed_special = None
        if "vocab_size" in tok_dict:
            kwargs["explicit_n_vocab"] = tok_dict["vocab_size"]

        tokenizer = tiktoken.Encoding(**kwargs)
        tokenizer._default_allowed_special = default_allowed_special or set()
        tokenizer._default_allowed_special |= {"<|separator|>"}

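        # Patch Encoding.encode so this tokenizer's default allowed special
        # tokens are always permitted without callers passing them explicitly.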
        def encode_patched(
            self,
            text: str,
            *,
            allowed_special: Union[
                Literal["all"], AbstractSet[str]
            ] = set(),  # noqa: B006
            disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        ) -> list[int]:
            if isinstance(allowed_special, set):
                allowed_special |= self._default_allowed_special
            return tiktoken.Encoding.encode(
                self,
                text,
                allowed_special=allowed_special,
                disallowed_special=disallowed_special,
            )

        tokenizer.encode = functools.partial(encode_patched, tokenizer)

        # Convert to HF interface
        self.tokenizer = tokenizer
        self.eos_token_id = tokenizer._special_tokens["<|eos|>"]
        self.vocab_size = tokenizer.n_vocab
        self.chat_template = Template(
            "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: '  + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
        )

    def encode(self, x, add_special_tokens=False):
        return self.tokenizer.encode(x)

    def decode(self, x):
        return self.tokenizer.decode(x)

    def batch_decode(
        self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
    ):
        if isinstance(batch[0], int):
            batch = [[x] for x in batch]
        return self.tokenizer.decode_batch(batch)

    def apply_chat_template(self, messages, tokenize, add_generation_prompt):
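        # Example: messages=[{"role": "user", "content": "Hi"}] with
        # add_generation_prompt=True renders "Human: Hi<|separator|>\n\nAssistant:".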
        ret = self.chat_template.render(
            messages=messages, add_generation_prompt=add_generation_prompt
        )
        return self.encode(ret) if tokenize else ret


class SentencePieceTokenizer:
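    """A SentencePiece tokenizer exposing the same minimal HuggingFace-style
    interface as TiktokenTokenizer."""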
    def __init__(self, tokenizer_path):
        import sentencepiece as spm
        from jinja2 import Template

        tokenizer = spm.SentencePieceProcessor(model_file=tokenizer_path)

        # Convert to HF interface
        self.tokenizer = tokenizer
        self.eos_token_id = tokenizer.eos_id()
        self.vocab_size = tokenizer.vocab_size()
        self.chat_template = Template(
            "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: '  + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
        )

    def encode(self, x, add_special_tokens=False):
        return self.tokenizer.encode(x)

    def decode(self, x):
        return self.tokenizer.decode(x)

    def batch_decode(
        self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
    ):
        if isinstance(batch[0], int):
            batch = [[x] for x in batch]
        return self.tokenizer.decode(batch)

    def apply_chat_template(self, messages, tokenize, add_generation_prompt):
        ret = self.chat_template.render(
            messages=messages, add_generation_prompt=add_generation_prompt
        )
        return self.encode(ret) if tokenize else ret