Unverified commit fed02a49, authored by FlyPanda, committed by GitHub
Browse files

[bugfix] fix deepseekvl2 and deepseek_ocr model type conflict (#12050)


Co-authored-by: herta <herta@pplabs.org>
parent 03b3e89a
......@@ -197,14 +197,6 @@ def get_config(
config = AutoConfig.from_pretrained(
model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
)
if (
getattr(config, "auto_map", None) is not None
and config.auto_map.get("AutoModel")
== "modeling_deepseekocr.DeepseekOCRForCausalLM"
):
config.model_type = "deepseek-ocr"
# TODO: Remove this workaround when AutoConfig correctly identifies deepseek-ocr.
# Hugging Face's AutoConfig currently misidentifies it as deepseekvl2.
except ValueError as e:
if not "deepseek_v32" in str(e):
......@@ -241,7 +233,15 @@ def get_config(
setattr(config, key, val)
if config.model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[config.model_type]
model_type = config.model_type
if model_type == "deepseek_vl_v2":
if (
getattr(config, "auto_map", None) is not None
and config.auto_map.get("AutoModel")
== "modeling_deepseekocr.DeepseekOCRForCausalLM"
):
model_type = "deepseek-ocr"
config_class = _CONFIG_REGISTRY[model_type]
config = config_class.from_pretrained(model, revision=revision)
# NOTE(HandH1998): Qwen2VL requires `_name_or_path` attribute in `config`.
setattr(config, "_name_or_path", model)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment