Unverified Commit 824a77d0 authored by Ke Bao's avatar Ke Bao Committed by GitHub
Browse files

Fix hf config loading (#702)

parent cf99eab7
...@@ -4,19 +4,26 @@ import functools ...@@ -4,19 +4,26 @@ import functools
import json import json
import os import os
import warnings import warnings
from typing import AbstractSet, Collection, Literal, Optional, Union from typing import AbstractSet, Collection, Dict, Literal, Optional, Type, Union
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from transformers import ( from transformers import (
AutoConfig, AutoConfig,
AutoProcessor, AutoProcessor,
AutoTokenizer, AutoTokenizer,
PretrainedConfig,
PreTrainedTokenizer, PreTrainedTokenizer,
PreTrainedTokenizerFast, PreTrainedTokenizerFast,
) )
from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig
from sglang.srt.utils import is_multimodal_model from sglang.srt.utils import is_multimodal_model
_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
ChatGLMConfig.model_type: ChatGLMConfig,
DbrxConfig.model_type: DbrxConfig,
}
def download_from_hf(model_path: str): def download_from_hf(model_path: str):
if os.path.exists(model_path): if os.path.exists(model_path):
...@@ -40,6 +47,9 @@ def get_config( ...@@ -40,6 +47,9 @@ def get_config(
config = AutoConfig.from_pretrained( config = AutoConfig.from_pretrained(
model, trust_remote_code=trust_remote_code, revision=revision model, trust_remote_code=trust_remote_code, revision=revision
) )
if config.model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[config.model_type]
config = config_class.from_pretrained(model, revision=revision)
if model_overide_args: if model_overide_args:
config.update(model_overide_args) config.update(model_overide_args)
return config return config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment