Unverified Commit b21fdd53 authored by Kevin Tuan's avatar Kevin Tuan Committed by GitHub
Browse files

feat: (chat-template matching) enhance multimodal model detection with config.json (#9597)

parent c04c17ed
...@@ -26,6 +26,8 @@ Key components: ...@@ -26,6 +26,8 @@ Key components:
# Adapted from # Adapted from
# https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py # https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
import dataclasses import dataclasses
import json
import os
import re import re
from enum import IntEnum, auto from enum import IntEnum, auto
from typing import Callable, Dict, List, Optional, Tuple, Union from typing import Callable, Dict, List, Optional, Tuple, Union
...@@ -959,16 +961,42 @@ register_conv_template( ...@@ -959,16 +961,42 @@ register_conv_template(
) )
# Maps the ``model_type`` value read from a model's ``config.json`` to the
# conversation template name returned by the matching functions below.
MODEL_TYPE_TO_TEMPLATE: Dict[str, str] = {
    "internvl_chat": "internvl-2-5",
    "deepseek_vl_v2": "deepseek-vl2",
    "multi_modality": "janus-pro",
    "phi4mm": "phi-4-mm",
    "minicpmv": "minicpmv",
    "minicpmo": "minicpmo",
}
def get_model_type(model_path: str) -> Optional[str]:
    """Return the ``model_type`` declared in ``<model_path>/config.json``.

    This is a best-effort lookup: any problem reading or interpreting the
    file yields ``None`` instead of an exception.

    Args:
        model_path: Directory expected to contain a ``config.json`` file.

    Returns:
        The ``model_type`` string, or ``None`` when the file is missing or
        unreadable, is not valid UTF-8 JSON, is not a JSON object, or has
        no string-valued ``model_type`` entry.
    """
    config_path = os.path.join(model_path, "config.json")
    try:
        # Open directly instead of checking existence first: a missing file
        # raises FileNotFoundError (an OSError), so no separate check is needed.
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
    except (OSError, ValueError):
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError
        # (file is not valid UTF-8).
        return None

    # A JSON document may legally be a list/number/string; only objects can
    # carry a "model_type" key.
    if not isinstance(config, dict):
        return None

    model_type = config.get("model_type")
    # Callers use the result as a dict key; reject non-string values so an
    # unhashable value (e.g. a list) cannot blow up the lookup.
    return model_type if isinstance(model_type, str) else None
@register_conv_template_matching_function
def match_internvl(model_path: str):
    """Pick a conversation template for InternVL-style model paths.

    Returns "internvl-2-5" when the path contains "internvl"
    (case-insensitive). Otherwise falls back to whatever template
    MODEL_TYPE_TO_TEMPLATE maps the config.json ``model_type`` to, which
    is ``None`` when there is no mapping.
    """
    name_hit = re.search(r"internvl", model_path, re.IGNORECASE)
    if name_hit:
        return "internvl-2-5"
    # Path gave no hint: consult the model's config.json instead.
    return MODEL_TYPE_TO_TEMPLATE.get(get_model_type(model_path))
@register_conv_template_matching_function
def match_deepseek_janus_pro(model_path: str):
    """Pick a conversation template for Janus-style model paths.

    Returns "janus-pro" when the path contains "janus" (case-insensitive).
    Otherwise falls back to whatever template MODEL_TYPE_TO_TEMPLATE maps
    the config.json ``model_type`` to, which is ``None`` when there is no
    mapping.
    """
    name_hit = re.search(r"janus", model_path, re.IGNORECASE)
    if name_hit:
        return "janus-pro"
    # Path gave no hint: consult the model's config.json instead.
    return MODEL_TYPE_TO_TEMPLATE.get(get_model_type(model_path))
@register_conv_template_matching_function @register_conv_template_matching_function
...@@ -981,6 +1009,8 @@ def match_vicuna(model_path: str): ...@@ -981,6 +1009,8 @@ def match_vicuna(model_path: str):
def match_deepseek_vl(model_path: str):
    """Pick a conversation template for DeepSeek-VL2-style model paths.

    Returns "deepseek-vl2" when the path matches ``deepseek.*vl2``
    (case-insensitive). Otherwise falls back to whatever template
    MODEL_TYPE_TO_TEMPLATE maps the config.json ``model_type`` to, which
    is ``None`` when there is no mapping.
    """
    name_hit = re.search(r"deepseek.*vl2", model_path, re.IGNORECASE)
    if name_hit:
        return "deepseek-vl2"
    # Path gave no hint: consult the model's config.json instead.
    return MODEL_TYPE_TO_TEMPLATE.get(get_model_type(model_path))
@register_conv_template_matching_function @register_conv_template_matching_function
...@@ -994,14 +1024,17 @@ def match_qwen_chat_ml(model_path: str): ...@@ -994,14 +1024,17 @@ def match_qwen_chat_ml(model_path: str):
@register_conv_template_matching_function
def match_minicpm(model_path: str):
    """Pick a conversation template for MiniCPM-V / MiniCPM-o model paths.

    When the path contains "minicpm-v" or "minicpm-o" (case-insensitive),
    returns "minicpmv" or "minicpmo" respectively. Otherwise falls back to
    whatever template MODEL_TYPE_TO_TEMPLATE maps the config.json
    ``model_type`` to, which is ``None`` when there is no mapping.
    """
    name_hit = re.search(r"minicpm-(v|o)", model_path, re.IGNORECASE)
    if name_hit:
        # The captured letter may be upper-case in the path; template names
        # are lower-case.
        variant = name_hit.group(1).lower()
        return f"minicpm{variant}"
    # Path gave no hint: consult the model's config.json instead.
    return MODEL_TYPE_TO_TEMPLATE.get(get_model_type(model_path))
@register_conv_template_matching_function
def match_phi_4_mm(model_path: str):
    """Pick a conversation template for Phi-4-multimodal model paths.

    Returns "phi-4-mm" when the lower-cased path contains
    "phi-4-multimodal". Otherwise falls back to whatever template
    MODEL_TYPE_TO_TEMPLATE maps the config.json ``model_type`` to, which
    is ``None`` when there is no mapping.
    """
    lowered_path = model_path.lower()
    if "phi-4-multimodal" in lowered_path:
        return "phi-4-mm"
    # Path gave no hint: consult the model's config.json instead.
    return MODEL_TYPE_TO_TEMPLATE.get(get_model_type(model_path))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment