""" Copyright 2023-2024 SGLang Team Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ """Utilities for Huggingface Transformers.""" import contextlib import functools import json import os import warnings from typing import AbstractSet, Collection, Dict, List, Literal, Optional, Type, Union from huggingface_hub import snapshot_download from transformers import ( AutoConfig, AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedTokenizer, PreTrainedTokenizerFast, ) try: from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig from sglang.srt.configs import ExaoneConfig _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { ChatGLMConfig.model_type: ChatGLMConfig, DbrxConfig.model_type: DbrxConfig, ExaoneConfig.model_type: ExaoneConfig, } except ImportError: # We want this file to run without vllm dependency _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {} for name, cls in _CONFIG_REGISTRY.items(): with contextlib.suppress(ValueError): AutoConfig.register(name, cls) def download_from_hf(model_path: str): if os.path.exists(model_path): return model_path return snapshot_download(model_path, allow_patterns=["*.json", "*.bin", "*.model"]) def get_config( model: str, trust_remote_code: bool, revision: Optional[str] = None, model_override_args: Optional[dict] = None, ): config = AutoConfig.from_pretrained( model, trust_remote_code=trust_remote_code, revision=revision ) if config.model_type in _CONFIG_REGISTRY: config_class = _CONFIG_REGISTRY[config.model_type] config = config_class.from_pretrained(model, revision=revision) if model_override_args: config.update(model_override_args) return config # Models don't use the same configuration key for determining the maximum # context length. Store them here so we can sanely check them. # NOTE: The ordering here is important. Some models have two of these and we # have a preference for which value gets used. CONTEXT_LENGTH_KEYS = [ "max_sequence_length", "seq_length", "max_position_embeddings", "max_seq_len", "model_max_length", ] def get_context_length(config): """Get the context length of a model from a huggingface model configs.""" rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling: rope_scaling_factor = config.rope_scaling.get("factor", 1) if "original_max_position_embeddings" in rope_scaling: rope_scaling_factor = 1 if config.rope_scaling.get("rope_type", None) == "llama3": rope_scaling_factor = 1 else: rope_scaling_factor = 1 for key in CONTEXT_LENGTH_KEYS: val = getattr(config, key, None) if val is not None: return int(rope_scaling_factor * val) return 2048 # A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file. _FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer" def get_tokenizer( tokenizer_name: str, *args, tokenizer_mode: str = "auto", trust_remote_code: bool = False, tokenizer_revision: Optional[str] = None, **kwargs, ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: """Gets a tokenizer for the given model name via Huggingface.""" if tokenizer_mode == "slow": if kwargs.get("use_fast", False): raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") kwargs["use_fast"] = False try: tokenizer = AutoTokenizer.from_pretrained( tokenizer_name, *args, trust_remote_code=trust_remote_code, tokenizer_revision=tokenizer_revision, **kwargs, ) except TypeError as e: # The LLaMA tokenizer causes a protobuf error in some environments. err_msg = ( "Failed to load the tokenizer. If you are using a LLaMA V1 model " f"consider using '{_FAST_LLAMA_TOKENIZER}' instead of the " "original tokenizer." ) raise RuntimeError(err_msg) from e except ValueError as e: # If the error pertains to the tokenizer class not existing or not # currently being imported, suggest using the --trust-remote-code flag. if not trust_remote_code and ( "does not exist or is not currently imported." in str(e) or "requires you to execute the tokenizer file" in str(e) ): err_msg = ( "Failed to load the tokenizer. If the tokenizer is a custom " "tokenizer not yet available in the HuggingFace transformers " "library, consider setting `trust_remote_code=True` in LLM " "or using the `--trust-remote-code` flag in the CLI." ) raise RuntimeError(err_msg) from e else: raise e if not isinstance(tokenizer, PreTrainedTokenizerFast): warnings.warn( "Using a slow tokenizer. This might cause a significant " "slowdown. Consider using a fast tokenizer instead." ) return tokenizer def get_processor( tokenizer_name: str, *args, tokenizer_mode: str = "auto", trust_remote_code: bool = False, tokenizer_revision: Optional[str] = None, **kwargs, ): processor = AutoProcessor.from_pretrained( tokenizer_name, *args, trust_remote_code=trust_remote_code, tokenizer_revision=tokenizer_revision, **kwargs, ) return processor