Unverified Commit 824a77d0 authored by Ke Bao's avatar Ke Bao Committed by GitHub
Browse files

Fix hf config loading (#702)

parent cf99eab7
......@@ -4,19 +4,26 @@ import functools
import json
import os
import warnings
from typing import AbstractSet, Collection, Literal, Optional, Union
from typing import AbstractSet, Collection, Dict, Literal, Optional, Type, Union
from huggingface_hub import snapshot_download
from transformers import (
AutoConfig,
AutoProcessor,
AutoTokenizer,
PretrainedConfig,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
)
from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig
from sglang.srt.utils import is_multimodal_model
_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
ChatGLMConfig.model_type: ChatGLMConfig,
DbrxConfig.model_type: DbrxConfig,
}
def download_from_hf(model_path: str):
if os.path.exists(model_path):
......@@ -40,6 +47,9 @@ def get_config(
config = AutoConfig.from_pretrained(
model, trust_remote_code=trust_remote_code, revision=revision
)
if config.model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[config.model_type]
config = config_class.from_pretrained(model, revision=revision)
if model_overide_args:
config.update(model_overide_args)
return config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment