Commit fa9aaaa7 authored by myhloli's avatar myhloli
Browse files

fix: update model path handling in model.py and models_download_utils.py

parent c47faa4d
...@@ -62,7 +62,7 @@ class Mineru2QwenForCausalLM(nn.Module): ...@@ -62,7 +62,7 @@ class Mineru2QwenForCausalLM(nn.Module):
# load vision tower # load vision tower
mm_vision_tower = self.config.mm_vision_tower mm_vision_tower = self.config.mm_vision_tower
model_root_path = auto_download_and_get_model_root_path("/", "vlm") model_root_path = auto_download_and_get_model_root_path(mm_vision_tower, "vlm")
mm_vision_tower = f"{model_root_path}/{mm_vision_tower}" mm_vision_tower = f"{model_root_path}/{mm_vision_tower}"
if "clip" in mm_vision_tower: if "clip" in mm_vision_tower:
......
...@@ -57,8 +57,12 @@ def auto_download_and_get_model_root_path(relative_path: str, repo_mode='pipelin ...@@ -57,8 +57,12 @@ def auto_download_and_get_model_root_path(relative_path: str, repo_mode='pipelin
relative_path = relative_path.strip('/') relative_path = relative_path.strip('/')
cache_dir = snapshot_download(repo, allow_patterns=[relative_path, relative_path+"/*"]) cache_dir = snapshot_download(repo, allow_patterns=[relative_path, relative_path+"/*"])
elif repo_mode == 'vlm': elif repo_mode == 'vlm':
# VLM 模式下,直接下载整个模型目录 # VLM 模式下,根据 relative_path 的不同处理方式
cache_dir = snapshot_download(repo) if relative_path == "/":
cache_dir = snapshot_download(repo)
else:
relative_path = relative_path.strip('/')
cache_dir = snapshot_download(repo, allow_patterns=[relative_path, relative_path+"/*"])
if not cache_dir: if not cache_dir:
raise FileNotFoundError(f"Failed to download model: {relative_path} from {repo}") raise FileNotFoundError(f"Failed to download model: {relative_path} from {repo}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment