"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""Utilities for Huggingface Transformers."""

import contextlib
import functools
import json
import os
import warnings
from typing import AbstractSet, Collection, Dict, List, Literal, Optional, Type, Union

from huggingface_hub import snapshot_download
from transformers import (
    AutoConfig,
    AutoProcessor,
    AutoTokenizer,
    PretrainedConfig,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
)

try:
    from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig

    from sglang.srt.configs import ExaoneConfig

    _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
        ChatGLMConfig.model_type: ChatGLMConfig,
        DbrxConfig.model_type: DbrxConfig,
        ExaoneConfig.model_type: ExaoneConfig,
    }
except ImportError:
    # We want this file to run without the vllm dependency.
    _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {}

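# Register any custom config classes with AutoConfig so that
# `AutoConfig.from_pretrained` can resolve their model types; a ValueError from
# re-registering an already-known type is suppressed.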
for name, cls in _CONFIG_REGISTRY.items():
    with contextlib.suppress(ValueError):
        AutoConfig.register(name, cls)


def download_from_hf(model_path: str):
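    """Return `model_path` if it is a local path; otherwise download the
    model's config, weight, and tokenizer files from the Hugging Face Hub
    and return the local snapshot directory."""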
    if os.path.exists(model_path):
        return model_path

    return snapshot_download(model_path, allow_patterns=["*.json", "*.bin", "*.model"])


def get_config(
    model: str,
    trust_remote_code: bool,
    revision: Optional[str] = None,
    model_override_args: Optional[dict] = None,
):
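    """Load a model config via AutoConfig, re-loading it with a registered
    custom config class (e.g. ChatGLM, DBRX, EXAONE) when the model type
    matches, and apply any `model_override_args` on top of it."""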
    config = AutoConfig.from_pretrained(
        model, trust_remote_code=trust_remote_code, revision=revision
    )
    if config.model_type in _CONFIG_REGISTRY:
        config_class = _CONFIG_REGISTRY[config.model_type]
        config = config_class.from_pretrained(model, revision=revision)
    if model_override_args:
        config.update(model_override_args)
    return config


# Models don't all use the same configuration key for the maximum context
# length. Store the known keys here so we can check them in one place.
# NOTE: The ordering here is important. Some models define more than one of
# these keys, and the first matching key in this list wins.
CONTEXT_LENGTH_KEYS = [
    "max_sequence_length",
    "seq_length",
    "max_position_embeddings",
    "max_seq_len",
    "model_max_length",
]


def get_context_length(config):
    """Get the context length of a model from a huggingface model configs."""
    rope_scaling = getattr(config, "rope_scaling", None)
    if rope_scaling:
        rope_scaling_factor = config.rope_scaling["factor"]
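        # If the config already reports the extended context window (it has
        # `original_max_position_embeddings`, or uses Llama-3-style RoPE
        # scaling), `max_position_embeddings` is already scaled, so do not
        # apply the factor again.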
        if "original_max_position_embeddings" in rope_scaling:
            rope_scaling_factor = 1
        if config.rope_scaling.get("rope_type", None) == "llama3":
            rope_scaling_factor = 1
    else:
        rope_scaling_factor = 1

    for key in CONTEXT_LENGTH_KEYS:
        val = getattr(config, key, None)
        if val is not None:
            return int(rope_scaling_factor * val)
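    # Fall back to a conservative default when none of the known keys is set.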
    return 2048


# A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file.
_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"


def get_tokenizer(
    tokenizer_name: str,
    *args,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    tokenizer_revision: Optional[str] = None,
    **kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Gets a tokenizer for the given model name via Huggingface."""
    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
            raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
        kwargs["use_fast"] = False

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name,
            *args,
            trust_remote_code=trust_remote_code,
            tokenizer_revision=tokenizer_revision,
            **kwargs,
        )
    except TypeError as e:
        # The LLaMA tokenizer causes a protobuf error in some environments.
        err_msg = (
            "Failed to load the tokenizer. If you are using a LLaMA V1 model "
            f"consider using '{_FAST_LLAMA_TOKENIZER}' instead of the "
            "original tokenizer."
        )
        raise RuntimeError(err_msg) from e
    except ValueError as e:
        # If the error pertains to the tokenizer class not existing or not
        # currently being imported, suggest using the --trust-remote-code flag.
        if not trust_remote_code and (
            "does not exist or is not currently imported." in str(e)
            or "requires you to execute the tokenizer file" in str(e)
        ):
            err_msg = (
                "Failed to load the tokenizer. If the tokenizer is a custom "
                "tokenizer not yet available in the HuggingFace transformers "
                "library, consider setting `trust_remote_code=True` in LLM "
                "or using the `--trust-remote-code` flag in the CLI."
            )
            raise RuntimeError(err_msg) from e
        else:
            raise e

    if not isinstance(tokenizer, PreTrainedTokenizerFast):
        warnings.warn(
            "Using a slow tokenizer. This might cause a significant "
            "slowdown. Consider using a fast tokenizer instead."
        )
    return tokenizer


def get_processor(
    tokenizer_name: str,
    *args,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    tokenizer_revision: Optional[str] = None,
    **kwargs,
):
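    """Get a processor (tokenizer plus image/audio preprocessors) for the
    given model name via Hugging Face."""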
    processor = AutoProcessor.from_pretrained(
        tokenizer_name,
        *args,
        trust_remote_code=trust_remote_code,
        tokenizer_revision=tokenizer_revision,
        **kwargs,
    )
    return processor
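

if __name__ == "__main__":
    # Minimal usage sketch: assumes network access to the Hugging Face Hub;
    # "gpt2" is used here only as a small, ungated example model id.
    example_model = "gpt2"
    example_config = get_config(example_model, trust_remote_code=False)
    print("context length:", get_context_length(example_config))
    example_tokenizer = get_tokenizer(example_model)
    print("vocab size:", example_tokenizer.vocab_size)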