Commit b21fdd53 (unverified), authored by Kevin Tuan, committed by GitHub
Browse files

feat: (chat-template matching) enhance multimodal model detection with config.json (#9597)

parent c04c17ed
......@@ -26,6 +26,8 @@ Key components:
# Adapted from
# https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
import dataclasses
import json
import os
import re
from enum import IntEnum, auto
from typing import Callable, Dict, List, Optional, Tuple, Union
......@@ -959,16 +961,42 @@ register_conv_template(
)
# Lookup table used by the matching functions as a fallback when the model
# path itself does not identify the model: the key is the ``model_type``
# string read from a checkpoint's ``config.json`` and the value is the name
# of the conversation template to use for it.
MODEL_TYPE_TO_TEMPLATE: Dict[str, str] = {
    "internvl_chat": "internvl-2-5",
    "deepseek_vl_v2": "deepseek-vl2",
    "multi_modality": "janus-pro",
    "phi4mm": "phi-4-mm",
    "minicpmv": "minicpmv",
    "minicpmo": "minicpmo",
}
def get_model_type(model_path: str) -> Optional[str]:
    """Return the ``model_type`` declared in ``<model_path>/config.json``.

    This is a best-effort probe: any problem reading or interpreting the
    file yields ``None`` instead of an exception, so callers can treat
    "unknown" and "unreadable" the same way.

    Args:
        model_path: Directory expected to contain a ``config.json`` file.

    Returns:
        The ``model_type`` string, or ``None`` when the file is missing or
        unreadable, is not valid UTF-8 JSON, does not hold a JSON object,
        or has no string-valued ``model_type`` entry.
    """
    config_path = os.path.join(model_path, "config.json")
    try:
        # Open directly rather than checking os.path.exists() first: a
        # missing file surfaces as OSError and is handled below, without a
        # window between the check and the open.
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
    except (OSError, ValueError):
        # OSError: missing/unreadable file (IOError is an alias of it).
        # ValueError: json.JSONDecodeError and UnicodeDecodeError are both
        # subclasses, so malformed JSON and non-UTF-8 bytes land here.
        return None

    # Valid JSON need not be an object (e.g. ``[]`` or ``null``); only a
    # mapping can carry a "model_type" key.
    if not isinstance(config, dict):
        return None

    model_type = config.get("model_type")
    # Honour the Optional[str] contract: callers use the result as a dict
    # key, so a list/dict value here would raise TypeError downstream.
    return model_type if isinstance(model_type, str) else None
@register_conv_template_matching_function
def match_internvl(model_path: str):
    """Pick a conversation template for InternVL-style model paths.

    A case-insensitive "internvl" anywhere in ``model_path`` selects the
    "internvl-2-5" template. Otherwise the ``model_type`` reported by
    ``get_model_type`` is looked up in ``MODEL_TYPE_TO_TEMPLATE``; note that
    this fallback returns the mapped template for *any* known model type,
    or ``None`` when there is no mapping.
    """
    name_hit = re.search(r"internvl", model_path, re.IGNORECASE)
    if name_hit is None:
        # Path gives no hint; defer to the model type from config.json.
        return MODEL_TYPE_TO_TEMPLATE.get(get_model_type(model_path))
    return "internvl-2-5"
@register_conv_template_matching_function
def match_deepseek_janus_pro(model_path: str):
    """Pick a conversation template for Janus-style model paths.

    Returns "janus-pro" when ``model_path`` contains "janus" (ignoring
    case). If it does not, the result is whatever ``MODEL_TYPE_TO_TEMPLATE``
    maps the ``get_model_type`` value to, which is ``None`` for an unknown
    or unavailable model type.
    """
    path_names_janus = bool(re.search(r"janus", model_path, re.IGNORECASE))
    if path_names_janus:
        return "janus-pro"

    # No name match: consult the model type declared in config.json.
    detected_type = get_model_type(model_path)
    template_name = MODEL_TYPE_TO_TEMPLATE.get(detected_type)
    return template_name
@register_conv_template_matching_function
......@@ -981,6 +1009,8 @@ def match_vicuna(model_path: str):
def match_deepseek_vl(model_path: str):
    """Pick a conversation template for DeepSeek-VL2-style model paths.

    The path matches when "deepseek" is followed, anywhere later in the
    string, by "vl2" (case-insensitive); that yields "deepseek-vl2".
    Without a match the function falls back to looking up the
    ``get_model_type`` result in ``MODEL_TYPE_TO_TEMPLATE`` and returns the
    mapped template name or ``None``.
    """
    if re.search(r"deepseek.*vl2", model_path, re.IGNORECASE) is not None:
        return "deepseek-vl2"
    # Fallback keyed on the model type declared in config.json.
    return MODEL_TYPE_TO_TEMPLATE.get(get_model_type(model_path))
@register_conv_template_matching_function
......@@ -994,14 +1024,17 @@ def match_qwen_chat_ml(model_path: str):
@register_conv_template_matching_function
def match_openbmb_minicpm(model_path: str):
    """Pick the MiniCPM template named by ``model_path``, if any.

    The patterns are tried in order, case-insensitively: "minicpm-v" maps
    to "minicpmv", then "minicpm-o" maps to "minicpmo". Returns ``None``
    when neither pattern occurs in the path.
    """
    # Order matters: the "-v" variant takes precedence when both appear.
    candidates = (
        (r"minicpm-v", "minicpmv"),
        (r"minicpm-o", "minicpmo"),
    )
    for pattern, template_name in candidates:
        if re.search(pattern, model_path, re.IGNORECASE):
            return template_name
    return None
def match_minicpm(model_path: str):
    """Pick a conversation template for MiniCPM-V / MiniCPM-O model paths.

    The first case-insensitive occurrence of "minicpm-v" or "minicpm-o" in
    ``model_path`` decides the result: "minicpmv" or "minicpmo"
    respectively. When the path contains neither, the ``get_model_type``
    value is looked up in ``MODEL_TYPE_TO_TEMPLATE`` and the mapped template
    name (or ``None``) is returned.
    """
    hit = re.search(r"minicpm-(v|o)", model_path, re.IGNORECASE)
    if hit is None:
        # Path gives no hint; defer to the model type from config.json.
        return MODEL_TYPE_TO_TEMPLATE.get(get_model_type(model_path))
    # group(1) is the captured variant letter; lower-case it so "MiniCPM-V"
    # and "minicpm-v" produce the same template name.
    return f"minicpm{hit.group(1).lower()}"
@register_conv_template_matching_function
def match_phi_4_mm(model_path: str):
    """Pick a conversation template for Phi-4 multimodal model paths.

    Returns "phi-4-mm" when the lower-cased ``model_path`` contains
    "phi-4-multimodal". Otherwise returns the ``MODEL_TYPE_TO_TEMPLATE``
    entry for the ``get_model_type`` result, or ``None`` if there is none.
    """
    lowered_path = model_path.lower()
    if "phi-4-multimodal" not in lowered_path:
        # No name match: consult the model type declared in config.json.
        detected_type = get_model_type(model_path)
        return MODEL_TYPE_TO_TEMPLATE.get(detected_type)
    return "phi-4-mm"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment