# 简单修复:只导入必要的模块,避免与huggingface_hub的冲突 import sys import os import json import yaml import time import threading import queue import shutil import multiprocessing from datetime import datetime from typing import Dict, List, Optional # 移除了huggingface_hub的Mock替换,因为已经安装了真实的库 # 仅导入Gradio的基本组件,避免使用需要huggingface_hub的功能 import gradio as gr from typing import Dict, List, Optional try: from modelscope.hub.snapshot_download import snapshot_download from modelscope.hub.api import HubApi HAS_MODELSCOPE = True except ImportError: print("警告: modelscope未安装") HAS_MODELSCOPE = False try: from pycsghub.csghub_api import CsgHubApi from pycsghub.upload_large_folder.main import upload_large_folder_internal, create_repo HAS_PYCSGHUB = True except ImportError: print("警告: pycsghub未安装") HAS_PYCSGHUB = False CONFIG_FILE = "config.yaml" STATE_FILE = "model_manager_state.json" DEFAULT_CONFIG = { "local": { "default_model_path": "", "models_db_path": "models_db.json", "config_db_path": "config_db.json" }, "csghub": { "base_url": "", "token": "", "repo_type": "model", "revision": "main" }, "download": { "max_retries": 10, "retry_interval": 5, "max_concurrent": 1 }, "upload": { "create_repo_default": True, "num_workers": 1 } } class GlobalState: _instance = None _lock = threading.RLock() def __new__(cls): if cls._instance is None: with cls._lock: if cls._instance is None: cls._instance = super().__new__(cls) cls._instance._initialized = False return cls._instance def __init__(self): if self._initialized: return self._initialized = True self.config = DEFAULT_CONFIG.copy() self.state = { "download_tasks": {}, "upload_tasks": {}, "local_models": [], "remote_cache": {} } self.download_queue = queue.PriorityQueue() self.upload_queue = queue.Queue() self.active_downloads = {} # task_id -> model_id self.download_processes = {} # task_id -> process object self.active_uploads = {} self.operation_lock = threading.RLock() self.load_config() self.load_state() def load_config(self): if os.path.exists(CONFIG_FILE): try: with open(CONFIG_FILE, 'r', encoding='utf-8') as f: loaded = yaml.safe_load(f) if loaded: # Merge loaded config with default config for key, value in loaded.items(): if key in self.config: if isinstance(value, dict) and isinstance(self.config[key], dict): self.config[key].update(value) else: self.config[key] = value except Exception as e: print(f"加载配置失败: {e}") def save_config(self): try: with open(CONFIG_FILE, 'w', encoding='utf-8') as f: yaml.dump(self.config, f, default_flow_style=False, allow_unicode=True) except Exception as e: print(f"保存配置失败: {e}") def load_state(self): if os.path.exists(STATE_FILE): try: with open(STATE_FILE, 'r', encoding='utf-8') as f: loaded = json.load(f) # 加载下载任务并将所有需要下载的模型初始状态设为pause download_tasks = loaded.get("download_tasks", {}) for task_id, task in download_tasks.items(): # 如果任务状态是pending或downloading,将其设置为pause if task.get("status") in ["pending", "downloading"]: task["status"] = "paused" task["message"] = "任务已暂停" self.state["download_tasks"] = download_tasks self.state["upload_tasks"] = loaded.get("upload_tasks", {}) self.state["remote_cache"] = loaded.get("remote_cache", {}) # 保存更新后的状态 self.save_state() except Exception as e: print(f"加载状态失败: {e}") def save_state(self): try: with open(STATE_FILE, 'w', encoding='utf-8') as f: json.dump(self.state, f, indent=2, ensure_ascii=False) except Exception as e: print(f"保存状态失败: {e}") global_state = GlobalState() def format_size(size_bytes): if size_bytes < 1024: return f"{size_bytes}B" elif size_bytes < 1024 * 1024: return f"{size_bytes / 1024:.1f}KB" elif size_bytes < 1024 * 1024 * 1024: return f"{size_bytes / (1024 * 1024):.1f}MB" else: return f"{size_bytes / (1024 * 1024 * 1024):.1f}GB" def get_dir_size(path): total = 0 if os.path.exists(path): for root, dirs, files in os.walk(path): for f in files: fp = os.path.join(root, f) try: total += os.path.getsize(fp) except: pass return total def get_file_count(path): count = 0 if os.path.exists(path): for root, dirs, files in os.walk(path): count += len(files) return count def estimate_model_size(model_id): cache_key = f"size_{model_id}" if cache_key in global_state.state["remote_cache"]: cached = global_state.state["remote_cache"][cache_key] if time.time() - cached.get("ts", 0) < 3600: return cached.get("size", "未知") if not HAS_MODELSCOPE: return "未知" try: api = HubApi() info = api.get_model(model_id) # 使用StorageSize获取模型大小,单位为字节 storage_size = info.get("StorageSize", 0) # 转换为人类可读格式 if storage_size > 0: if storage_size >= 1024 * 1024 * 1024: size_str = f"{storage_size / (1024 * 1024 * 1024):.2f} GB" elif storage_size >= 1024 * 1024: size_str = f"{storage_size / (1024 * 1024):.2f} MB" elif storage_size >= 1024: size_str = f"{storage_size / 1024:.2f} KB" else: size_str = f"{storage_size} B" else: size_str = "未知" global_state.state["remote_cache"][cache_key] = {"size": size_str, "ts": time.time()} global_state.save_state() return size_str except Exception as e: print(f"预估大小失败: {e}") # 缓存失败结果,避免重复请求不存在的模型 global_state.state["remote_cache"][cache_key] = {"size": "未知", "ts": time.time()} global_state.save_state() return "未知" def get_remote_file_count(model_id): cache_key = f"file_count_{model_id}" if cache_key in global_state.state["remote_cache"]: cached = global_state.state["remote_cache"][cache_key] if time.time() - cached.get("ts", 0) < 3600: return cached.get("count", 0) if not HAS_MODELSCOPE: return 0 try: api = HubApi() info = api.get_model(model_id) # 获取模型文件数量 - 参考ms-demo.py的实现 file_count = 0 if "ModelInfos" in info: # 优先检查safetensor文件 if "safetensor" in info["ModelInfos"] and "files" in info["ModelInfos"]["safetensor"]: file_count = len(info["ModelInfos"]["safetensor"]["files"]) else: # 如果没有safetensor,统计所有文件 for library in info["ModelInfos"].values(): if "files" in library and isinstance(library["files"], list): file_count += len(library["files"]) global_state.state["remote_cache"][cache_key] = {"count": file_count, "ts": time.time()} global_state.save_state() return file_count except Exception as e: print(f"获取文件数量失败: {e}") # 缓存失败结果,避免重复请求不存在的模型 global_state.state["remote_cache"][cache_key] = {"count": 0, "ts": time.time()} global_state.save_state() return 0 def scan_local_models(): local_dir = global_state.config.get("local", {}).get("default_model_path", "") if not local_dir or not os.path.exists(local_dir): return [] models = [] try: for org in os.listdir(local_dir): org_path = os.path.join(local_dir, org) if not os.path.isdir(org_path): continue for model_name in os.listdir(org_path): model_path = os.path.join(org_path, model_name) if not os.path.isdir(model_path): continue has_model = False for f in os.listdir(model_path): if f.endswith(('.safetensors', '.bin', '.pth', '.pt', '.ckpt')): has_model = True break if has_model: model_id = f"{org}/{model_name}" size = get_dir_size(model_path) files = get_file_count(model_path) uploaded = False for tid, task in global_state.state["upload_tasks"].items(): if task.get("model_id") == model_id and task.get("status") == "completed": uploaded = True break models.append({ "id": model_id, "path": model_path, "size": format_size(size), "size_bytes": size, "file_count": files, "uploaded": uploaded, "status": "已上传" if uploaded else "已下载" }) except Exception as e: print(f"扫描失败: {e}") return models def create_download_task(model_id, priority=0): task_id = f"dl_{int(time.time()*1000)}_{hash(model_id) % 10000}" with global_state.operation_lock: global_state.state["download_tasks"][task_id] = { "task_id": task_id, "model_id": model_id, "priority": priority, "status": "pending", "progress": 0, "total_files": 0, "completed_files": 0, "estimated_size": estimate_model_size(model_id), "message": "等待下载...", "retry_count": 0, "auto_upload": global_state.config.get("auto_upload_after_download", False), "auto_delete": global_state.config.get("auto_delete_after_upload", False), "start_time": None, "end_time": None } global_state.save_state() global_state.download_queue.put((-priority, task_id)) # 如果只有一个任务且没有正在下载的任务,立即开始下载 with global_state.operation_lock: if len(global_state.state["download_tasks"]) == 1 and len(global_state.active_downloads) == 0: # 立即处理队列中的任务 pass # 工作线程会自动处理 return task_id def create_upload_task(model_id, local_path): task_id = f"ul_{int(time.time()*1000)}_{hash(model_id) % 10000}" with global_state.operation_lock: global_state.state["upload_tasks"][task_id] = { "task_id": task_id, "model_id": model_id, "local_path": local_path, "status": "pending", "progress": 0, "total_files": 0, "completed_files": 0, "message": "等待上传...", "auto_delete": global_state.config.get("auto_delete_after_upload", False), "start_time": None, "end_time": None } global_state.save_state() global_state.upload_queue.put(task_id) return task_id def download_worker(): while True: try: neg_priority, task_id = global_state.download_queue.get(timeout=1) with global_state.operation_lock: if task_id not in global_state.state["download_tasks"]: continue task = global_state.state["download_tasks"][task_id] model_id = task["model_id"] # 检查是否有相同模型的任务正在下载 model_in_progress = False for active_id, active_model in global_state.active_downloads.items(): if active_model == model_id: model_in_progress = True break if model_in_progress: global_state.download_queue.put((neg_priority, task_id)) continue # 检查任务当前状态 - 只有pending状态的任务才会被执行 if task.get("status") != "pending": # 如果是paused状态,不放入队列,保持暂停 if task.get("status") == "paused": continue # 其他状态(如completed, failed)不处理 continue global_state.active_downloads[task_id] = model_id task["status"] = "downloading" task["start_time"] = time.time() task["message"] = "开始下载..." global_state.save_state() max_retries = 10 retry = task.get("retry_count", 0) try: # 在调用snapshot_download之前再次检查任务状态 with global_state.operation_lock: if task_id not in global_state.state["download_tasks"]: continue task = global_state.state["download_tasks"][task_id] if task.get("status") != "downloading": # 如果任务已经不是下载状态,将其放回队列 global_state.download_queue.put((neg_priority, task_id)) continue # 执行下载 success = download_model_impl(task_id, model_id) # 如果返回False,表示任务被暂停或删除,不继续处理 if not success: with global_state.operation_lock: if task_id in global_state.state["download_tasks"]: task = global_state.state["download_tasks"][task_id] if task.get("status") == "paused": print(f"任务 {task_id} 已暂停,不继续处理") continue # 跳过后续的成功处理逻辑 with global_state.operation_lock: task = global_state.state["download_tasks"].get(task_id) if task and task["status"] == "completed": if task.get("auto_upload"): local_path = os.path.join( global_state.config.get("local", {}).get("default_model_path", ""), model_id.replace("/", os.path.sep) ) if os.path.exists(local_path): create_upload_task(model_id, local_path) except Exception as e: print(f"下载线程异常: {e}") with global_state.operation_lock: if task_id in global_state.state["download_tasks"]: task = global_state.state["download_tasks"][task_id] # 检查任务是否已暂停,如果是,跳过重试逻辑 if task.get("status") == "paused": print(f"任务 {task_id} 已暂停,跳过重试") continue # 跳过重试逻辑 retry += 1 task["retry_count"] = retry task["message"] = f"下载失败 (尝试 {retry}/{max_retries}): {str(e)[:50]}" global_state.save_state() finally: # 无论下载成功、失败还是被中断,都从active_downloads中移除任务 with global_state.operation_lock: if task_id in global_state.active_downloads: del global_state.active_downloads[task_id] # 检查任务是否被暂停,如果是,将其放回队列 if task_id in global_state.state["download_tasks"]: task = global_state.state["download_tasks"][task_id] if task.get("status") == "paused": task["message"] = "任务已暂停" global_state.save_state() global_state.download_queue.put((neg_priority, task_id)) if retry >= max_retries: with global_state.operation_lock: if task_id in global_state.state["download_tasks"]: task = global_state.state["download_tasks"][task_id] task["status"] = "failed" task["end_time"] = time.time() task["message"] = f"下载失败: 已达到最大重试次数" global_state.save_state() continue time.sleep(5) global_state.download_queue.task_done() except queue.Empty: continue except Exception as e: print(f"下载线程异常: {e}") time.sleep(1) def download_process_impl(task_id, model_id, local_dir): """在子进程中执行下载的函数""" print(f"download process start: {model_id}") try: local_path = os.path.join(local_dir, model_id.replace("/", os.path.sep)) os.makedirs(os.path.dirname(local_path), exist_ok=True) # 实际下载 snapshot_download( model_id=model_id, cache_dir=local_dir, revision="master" ) print(f"download process success: {model_id} -> {local_dir}") if os.path.exists(local_path): file_count = get_file_count(local_path) size = get_dir_size(local_path) return { "success": True, "file_count": file_count, "size": size, "size_str": format_size(size) } else: raise FileNotFoundError(f"下载后未找到: {local_path}") except Exception as e: print(f"download process failed: {e}") return { "success": False, "error": str(e) } def download_model_impl(task_id, model_id): """管理下载进程""" local_dir = global_state.config.get("local", {}).get("default_model_path", "") if not local_dir: raise ValueError("未设置本地目录") with global_state.operation_lock: if task_id in global_state.state["download_tasks"]: task = global_state.state["download_tasks"][task_id] task["message"] = "正在下载..." global_state.save_state() if not HAS_MODELSCOPE: raise ImportError("modelscope未安装") # 创建并启动下载进程 process = multiprocessing.Process( target=download_process_impl, args=(task_id, model_id, local_dir), name=f"download_process_{task_id}" ) with global_state.operation_lock: global_state.download_processes[task_id] = process process.start() # 等待进程完成或任务被暂停 while process.is_alive(): time.sleep(1) # 每秒检查一次 with global_state.operation_lock: if task_id not in global_state.state["download_tasks"]: # 任务已被删除 process.terminate() process.join(timeout=5) if process.is_alive(): process.kill() return False task = global_state.state["download_tasks"][task_id] if task.get("status") == "paused": # 任务已被暂停 process.terminate() process.join(timeout=5) if process.is_alive(): process.kill() return False # 进程完成,获取结果 with global_state.operation_lock: if task_id in global_state.download_processes: del global_state.download_processes[task_id] if process.exitcode == 0: # 下载成功 local_path = os.path.join(local_dir, model_id.replace("/", os.path.sep)) if os.path.exists(local_path): file_count = get_file_count(local_path) size = get_dir_size(local_path) with global_state.operation_lock: if task_id in global_state.state["download_tasks"]: task = global_state.state["download_tasks"][task_id] task["status"] = "completed" task["progress"] = 100 task["total_files"] = file_count task["completed_files"] = file_count task["message"] = f"下载完成: {file_count}文件, {format_size(size)}" task["end_time"] = time.time() global_state.save_state() return True else: raise FileNotFoundError(f"下载后未找到: {local_path}") else: # 下载失败 raise RuntimeError(f"下载进程退出码: {process.exitcode}") def upload_worker(): while True: try: task_id = global_state.upload_queue.get(timeout=1) with global_state.operation_lock: if task_id not in global_state.state["upload_tasks"]: continue task = global_state.state["upload_tasks"][task_id] model_id = task["model_id"] local_path = task["local_path"] for active_id in global_state.active_uploads: if global_state.active_uploads[active_id] == model_id: global_state.upload_queue.put(task_id) continue global_state.active_uploads[task_id] = model_id task["status"] = "uploading" task["start_time"] = time.time() task["message"] = "开始上传..." global_state.save_state() try: # 检查任务是否被暂停 with global_state.operation_lock: if task_id not in global_state.state["upload_tasks"]: continue task = global_state.state["upload_tasks"][task_id] if task.get("status") == "paused": task["message"] = "任务已暂停" global_state.save_state() # 将任务放回队列,以便后续可以继续执行 global_state.upload_queue.put(task_id) continue upload_model_impl(task_id, model_id, local_path) with global_state.operation_lock: task = global_state.state["upload_tasks"].get(task_id) if task and task["status"] == "completed" and task.get("auto_delete"): try: if os.path.exists(local_path): shutil.rmtree(local_path) task["message"] += " (已自动删除本地)" except Exception as e: print(f"自动删除失败: {e}") except Exception as e: with global_state.operation_lock: if task_id in global_state.state["upload_tasks"]: task = global_state.state["upload_tasks"][task_id] task["status"] = "failed" task["end_time"] = time.time() task["message"] = f"上传失败: {str(e)[:50]}" global_state.save_state() with global_state.operation_lock: if task_id in global_state.active_uploads: del global_state.active_uploads[task_id] global_state.upload_queue.task_done() except queue.Empty: continue except Exception as e: print(f"上传线程异常: {e}") time.sleep(1) def upload_model_impl(task_id, model_id, local_path): if not HAS_PYCSGHUB: raise ImportError("pycsghub未安装") csghub_config = global_state.config.get("csghub", {}) if not csghub_config.get("base_url") or not csghub_config.get("token"): raise ValueError("CSGHUB配置不完整") with global_state.operation_lock: if task_id in global_state.state["upload_tasks"]: task = global_state.state["upload_tasks"][task_id] task["message"] = "正在连接..." global_state.save_state() csg_api = CsgHubApi() repo_id = f"root/{model_id.split('/')[1]}" create_repo( api=csg_api, repo_id=repo_id, repo_type=csghub_config["repo_type"], revision=csghub_config["revision"], endpoint=csghub_config["base_url"], token=csghub_config["token"] ) file_count = get_file_count(local_path) with global_state.operation_lock: if task_id in global_state.state["upload_tasks"]: task = global_state.state["upload_tasks"][task_id] task["total_files"] = file_count task["message"] = f"上传中: 0/{file_count}" global_state.save_state() upload_large_folder_internal( repo_id=repo_id, local_path=local_path, repo_type=csghub_config["repo_type"], revision=csghub_config["revision"], endpoint=csghub_config["base_url"], token=csghub_config["token"], num_workers=1, print_report=False, allow_patterns=None, ignore_patterns=None, print_report_every=1, ) with global_state.operation_lock: if task_id in global_state.state["upload_tasks"]: task = global_state.state["upload_tasks"][task_id] task["status"] = "completed" task["progress"] = 100 task["completed_files"] = file_count task["message"] = f"上传完成: {file_count}文件" task["end_time"] = time.time() global_state.save_state() def get_downloads_data(): with global_state.operation_lock: tasks = list(global_state.state["download_tasks"].values()) if not tasks: return [] tasks.sort(key=lambda x: ( -x.get("priority", 0), 0 if x["status"] == "downloading" else 1, x.get("start_time") or 0 )) data = [] for t in tasks: total_files = t.get("total_files", 0) completed_files = t.get("completed_files", 0) files_info = f"{completed_files}/{total_files}" if total_files > 0 else "0/0" data.append([ t.get("task_id", ""), t.get("model_id", ""), t.get("status", ""), t.get("progress", 0), t.get("estimated_size", "未知"), files_info, t.get("message", "") ]) return data def get_uploads_data(): with global_state.operation_lock: tasks = list(global_state.state["upload_tasks"].values()) if not tasks: return [] data = [] for t in tasks: data.append([ t.get("model_id", ""), t.get("status", ""), t.get("progress", 0), t.get("completed_files", 0), t.get("total_files", 0), t.get("message", "") ]) return data def get_local_models_data(): models = scan_local_models() if not models: return [] models.sort(key=lambda x: (0 if x["uploaded"] else 1, -x["size_bytes"])) data = [] for m in models: data.append([ m.get("id", ""), m.get("status", ""), m.get("size", ""), m.get("file_count", 0), "是" if m.get("uploaded") else "否" ]) return data # 创建综合任务表格数据 def get_combined_tasks_data(): # 获取所有数据 downloads = get_downloads_data() uploads = get_uploads_data() local_models = get_local_models_data() # 创建综合数据列表 combined = [] # 添加下载任务 for task in downloads: combined.append({ "type": "download", "id": task[0], # 任务ID "model_id": task[1], "status": task[2], "progress": task[3], "size": task[4], # 这个已经包含了远程模型大小 "local_size": "", # 本地大小待下载完成后更新 "file_count": task[5].split("/")[-1] if task[5] else "0", "message": task[6] }) # 添加上传任务 for task in uploads: model_id = task[0] # 获取远程模型大小用于比较 remote_size = estimate_model_size(model_id) combined.append({ "type": "upload", "id": model_id, # 模型ID "model_id": model_id, "status": task[1], "progress": task[2], "size": remote_size, # 远程模型大小 "local_size": "", # 本地大小 "file_count": f"{task[3]}/{task[4]}", "message": task[5] }) # 添加本地模型 for model in local_models: model_id = model[0] # 获取远程模型大小用于比较 remote_size = estimate_model_size(model_id) combined.append({ "type": "local", "id": model_id, # 模型ID "model_id": model_id, "status": model[1], "progress": "", "size": remote_size, # 远程模型大小 "local_size": model[2], # 本地模型大小 "file_count": model[3], "message": f"已上传,本地大小: {model[2]}, 远程大小: {remote_size}" }) # 转换为表格数据格式 data = [] for item in combined: data.append([ item["model_id"], item["type"], # 任务类型 item["status"], item["progress"], item["size"], # 远程大小 item["local_size"], # 本地大小 item["file_count"], item["message"] ]) return data def refresh_all(): return ( get_downloads_data(), get_combined_tasks_data() ) def start_workers(): for i in range(global_state.config.get("download", {}).get("max_concurrent", 1)): t = threading.Thread(target=download_worker, daemon=True, name=f"dl_worker_{i}") t.start() for i in range(global_state.config.get("upload", {}).get("num_workers", 1)): t = threading.Thread(target=upload_worker, daemon=True, name=f"ul_worker_{i}") t.start() def state_cleaner(): while True: time.sleep(60) with global_state.operation_lock: for task_type, tasks in [("download", global_state.state["download_tasks"]), ("upload", global_state.state["upload_tasks"])]: to_remove = [] for tid, task in tasks.items(): if task.get("status") in ["completed", "failed", "cancelled"]: if time.time() - task.get("end_time", 0) > 300: to_remove.append(tid) for tid in to_remove: del tasks[tid] global_state.save_state() t = threading.Thread(target=state_cleaner, daemon=True, name="state_cleaner") t.start() def create_interface(): # 创建主题切换器 with gr.Blocks(title="模型下载管理器", theme=gr.themes.Base()) as app: gr.Markdown("# 模型下载管理器") # 创建左侧按钮组和主内容区域 with gr.Row(): # 左侧按钮列 with gr.Column(scale=1, min_width=150): config_btn = gr.Button("配置", variant="primary", size="lg") download_btn = gr.Button("下载", variant="secondary", size="lg") manage_btn = gr.Button("管理", variant="secondary", size="lg") # 右侧主内容区域 with gr.Column(scale=4): # 创建状态变量来跟踪当前选中的标签 current_tab = gr.State("config_tab") # 根据当前标签显示不同内容 with gr.Column(visible=True) as config_tab_content: gr.Markdown("## 本地路径配置") local_dir_input = gr.Textbox( label="本地模型目录", value=global_state.config.get("local", {}).get("default_model_path", ""), placeholder="请输入本地模型存储目录" ) with gr.Row(): config_path_btn = gr.Button("配置路径", variant="primary") clear_tasks_btn = gr.Button("删除之前任务列表", variant="secondary") config_result = gr.Textbox(label="配置结果", interactive=False) gr.Markdown("## 高级配置") auto_upload_cb = gr.Checkbox( label="下载后自动上传", value=global_state.config.get("auto_upload_after_download", False) ) auto_delete_cb = gr.Checkbox( label="上传后自动删除", value=global_state.config.get("auto_delete_after_upload", False) ) update_config_btn = gr.Button("更新高级配置") with gr.Column(visible=False) as download_tab_content: gr.Markdown("## 下载模型") model_id_input = gr.Textbox( label="模型ID", placeholder="例如: Qwen/Qwen2.5-7B-Instruct" ) add_dl_btn = gr.Button("添加下载任务", variant="primary") gr.Markdown("## 下载任务列表") selected_task = gr.Textbox(label="选中的任务ID", interactive=False) task_result = gr.Textbox(label="操作结果", interactive=False) with gr.Row(): refresh_btn = gr.Button("刷新任务列表") pause_btn = gr.Button("暂停任务") resume_btn = gr.Button("恢复任务") delete_task_btn = gr.Button("删除任务") with gr.Row(): move_top_btn = gr.Button("任务置顶") move_up_btn = gr.Button("任务上移") move_down_btn = gr.Button("任务下移") downloads_table = gr.Dataframe( headers=["任务ID", "模型ID", "状态", "进度", "预估大小", "文件数", "消息"], datatype=["str", "str", "str", "number", "str", "number", "str"], value=get_downloads_data(), interactive=False ) with gr.Column(visible=False) as manage_tab_content: # 顶部扫描模型按钮 gr.Markdown("## 模型任务管理") scan_btn = gr.Button("扫描模型", variant="primary", size="lg") scan_result = gr.Textbox(label="扫描结果", interactive=False, container=False) # 综合任务管理表格 combined_tasks_table = gr.Dataframe( headers=["模型ID", "任务类型", "状态", "进度", "远程大小", "本地大小", "文件数", "消息"], datatype=["str", "str", "str", "str", "str", "str", "str", "str"], value=get_combined_tasks_data(), interactive=False ) # 表格操作按钮 with gr.Row(): selected_model_id = gr.Textbox(label="选中的模型ID", interactive=False) # 综合任务表格选择 def select_combined_task(evt: gr.SelectData): try: # 获取当前显示的所有任务数据 tasks = get_combined_tasks_data() # 处理不同格式的索引 row_idx = None if isinstance(evt.index, (list, tuple)) and len(evt.index) >= 1: row_idx = evt.index[0] elif isinstance(evt.index, int): row_idx = evt.index # 获取选中行的模型ID if isinstance(row_idx, int) and 0 <= row_idx < len(tasks): return tasks[row_idx][0] # 返回模型ID except Exception as e: print(f"选择任务时出错: {e}") return "" combined_tasks_table.select(select_combined_task, outputs=[selected_model_id]) with gr.Row(): upload_task_btn = gr.Button("添加上传任务", variant="primary") delete_model_btn = gr.Button("删除模型", variant="secondary") upload_result = gr.Textbox(label="操作结果", interactive=False, container=False) # 批量操作按钮 with gr.Row(): delete_failed_tasks_btn = gr.Button("删除失败任务", variant="secondary") delete_completed_tasks_btn = gr.Button("删除已完成任务", variant="secondary") delete_tasks_result = gr.Textbox(label="批量操作结果", interactive=False, container=False) # 配置保存 - 仅保存路径,不清空任务 def save_config_path(d): if d: # 更新本地模型目录配置 global_state.config.get("local", {}).update({"default_model_path": d}) global_state.save_config() return "目录已设置" return "目录无效" # 清空任务列表 def clear_tasks(): with global_state.operation_lock: # 清空下载任务 global_state.state["download_tasks"].clear() # 清空上传任务 global_state.state["upload_tasks"].clear() # 重置队列 global_state.download_queue = queue.PriorityQueue() global_state.upload_queue = queue.Queue() # 清空活动任务 global_state.active_downloads.clear() global_state.active_uploads.clear() # 终止所有正在运行的下载进程 for task_id, process in list(global_state.download_processes.items()): try: process.terminate() process.join(timeout=5) if process.is_alive(): process.kill() except Exception as e: print(f"终止进程失败: {e}") global_state.download_processes.clear() return "任务列表已清空" # 配置路径按钮点击事件 config_path_btn.click( fn=save_config_path, inputs=[local_dir_input], outputs=[config_result] ).then(fn=refresh_all, outputs=[downloads_table, combined_tasks_table]) # 删除任务列表按钮点击事件 clear_tasks_btn.click( fn=clear_tasks, inputs=None, outputs=[config_result] ).then(fn=refresh_all, outputs=[downloads_table, combined_tasks_table]) # 高级配置 update_config_btn.click( fn=lambda au, ad: (global_state.config.update({ "auto_upload_after_download": au, "auto_delete_after_upload": ad }), global_state.save_config(), "配置已更新"), inputs=[auto_upload_cb, auto_delete_cb], outputs=[config_result] ) # 添加下载任务 def add_download(mid, au, ad): local_dir = global_state.config.get("local", {}).get("default_model_path", "") if not local_dir: return "请先设置本地目录" if mid: global_state.config.update({ "auto_upload_after_download": au, "auto_delete_after_upload": ad }) global_state.save_config() # 获取远程文件数量 remote_file_count = get_remote_file_count(mid) # 默认优先级为0 task_id = create_download_task(mid, 0) # 更新任务的总文件数 with global_state.operation_lock: if task_id in global_state.state["download_tasks"]: task = global_state.state["download_tasks"][task_id] task["total_files"] = remote_file_count global_state.save_state() return "任务已添加" add_dl_btn.click( fn=add_download, inputs=[model_id_input, auto_upload_cb, auto_delete_cb], outputs=None ).then(fn=refresh_all, outputs=[downloads_table, combined_tasks_table]) # 添加上传任务 def add_upload(mid): local_dir = global_state.config.get("local", {}).get("default_model_path", "") if not local_dir: return "请先设置本地目录" if mid: create_upload_task(mid, os.path.join( local_dir, mid.replace("/", os.path.sep) )) return "任务已添加" return "请先选择一个模型" upload_task_btn.click( fn=add_upload, inputs=[selected_model_id], outputs=[upload_result] ).then(fn=refresh_all, outputs=[downloads_table, combined_tasks_table]) # 删除模型 def delete_model(mid): local_dir = global_state.config.get("local", {}).get("default_model_path", "") if not local_dir: return "请先设置本地目录" if mid: model_path = os.path.join(local_dir, mid.replace("/", os.path.sep)) if os.path.exists(model_path): shutil.rmtree(model_path) return "已删除" return "模型不存在" return "请先选择一个模型" delete_model_btn.click( fn=delete_model, inputs=[selected_model_id], outputs=[upload_result] ).then(fn=refresh_all, outputs=[downloads_table, combined_tasks_table]) # 暂停任务 def pause_download(task_id): if not task_id: return "请先选择一个任务" with global_state.operation_lock: if task_id in global_state.state["download_tasks"]: task = global_state.state["download_tasks"][task_id] current_status = task.get("status") if current_status == "downloading": # 更新任务状态为暂停 task["status"] = "paused" task["message"] = "任务已暂停" # 从active_downloads中移除 if task_id in global_state.active_downloads: del global_state.active_downloads[task_id] # 保存状态,download_model_impl会检测到暂停状态并终止进程 global_state.save_state() return f"任务已暂停: {task_id}" elif current_status == "pending": # 对于等待中的任务,直接标记为暂停 task["status"] = "paused" task["message"] = "任务已暂停" global_state.save_state() return f"任务已暂停: {task_id}" else: return f"任务当前状态: {current_status},无法暂停" return "未找到该任务或任务不在可暂停状态" # 恢复任务 def resume_download(task_id): with global_state.operation_lock: if task_id in global_state.state["download_tasks"]: task = global_state.state["download_tasks"][task_id] if task.get("status") == "paused": task["status"] = "pending" task["message"] = "等待下载..." # 将任务放回队列 global_state.download_queue.put((-task.get("priority", 0), task_id)) global_state.save_state() return f"任务已恢复: {task_id}" return "未找到该任务或任务不在暂停状态" pause_btn.click( fn=pause_download, inputs=[selected_task], outputs=[task_result] ).then(fn=refresh_all, outputs=[downloads_table, combined_tasks_table]) # 恢复任务按钮 resume_btn.click( fn=resume_download, inputs=[selected_task], outputs=[task_result] ).then(fn=refresh_all, outputs=[downloads_table, combined_tasks_table]) # 添加表格行点击选择功能 def select_task(evt: gr.SelectData): # 简化的任务选择函数 try: # 直接从evt.value获取行数据(如果是整行选择) if isinstance(evt.value, list) and len(evt.value) > 0: # 如果evt.value是列表,第一个元素就是任务ID task_id = evt.value[0] return task_id, f"已选中任务: {task_id}" # 如果是单元格选择,使用索引获取行数据 row_idx = 0 if isinstance(evt.index, tuple): row_idx = evt.index[0] elif isinstance(evt.index, list): row_idx = evt.index[0] else: row_idx = evt.index downloads = get_downloads_data() if 0 <= row_idx < len(downloads): row_data = downloads[row_idx] if isinstance(row_data, list) and len(row_data) > 0: task_id = row_data[0] return task_id, f"已选中任务: {task_id}" except Exception as e: print(f"选择任务错误: {str(e)}") return "", "未选中任何任务" downloads_table.select( fn=select_task, inputs=[], outputs=[selected_task, task_result] ) # 根据任务ID删除任务 def delete_download_task(task_id): with global_state.operation_lock: if task_id in global_state.state["download_tasks"]: # 如果任务正在下载,先从active_downloads中移除 if task_id in global_state.active_downloads: del global_state.active_downloads[task_id] # 终止相关的下载进程 if task_id in global_state.download_processes: process = global_state.download_processes[task_id] process.terminate() try: process.join(timeout=3) if process.is_alive(): process.kill() except Exception: pass finally: # 从下载进程字典中移除 del global_state.download_processes[task_id] # 删除任务记录 del global_state.state["download_tasks"][task_id] global_state.save_state() return f"已删除任务: {task_id}" return "未找到该任务" # 任务置顶 def move_task_top(task_id): with global_state.operation_lock: if task_id not in global_state.state["download_tasks"]: return "未找到该任务" # 找到当前最高优先级 max_priority = max((task.get("priority", 0) for task in global_state.state["download_tasks"].values()), default=0) # 设置当前任务为最高优先级+1 task = global_state.state["download_tasks"][task_id] old_priority = task.get("priority", 0) task["priority"] = max_priority + 1 # 如果任务正在下载,将其暂停 if task.get("status") == "downloading": task["status"] = "paused" task["message"] = "任务已暂停" # 从active_downloads中移除 if task_id in global_state.active_downloads: del global_state.active_downloads[task_id] # 终止正在运行的下载进程 if task_id in global_state.download_processes: process = global_state.download_processes[task_id] process.terminate() try: process.join(timeout=3) if process.is_alive(): process.kill() except Exception: pass finally: del global_state.download_processes[task_id] global_state.save_state() return f"任务 {task_id} 已置顶" # 任务上移 def move_task_up(task_id): with global_state.operation_lock: if task_id not in global_state.state["download_tasks"]: return "未找到该任务" # 获取所有任务并按优先级排序(降序) tasks = sorted(global_state.state["download_tasks"].values(), key=lambda x: (-x.get("priority", 0), x.get("start_time", 0))) # 找到当前任务的索引 current_index = -1 for i, task in enumerate(tasks): if task["task_id"] == task_id: current_index = i break if current_index <= 0: return "任务已经在最顶部" # 与前一个任务交换优先级 current_task = tasks[current_index] prev_task = tasks[current_index - 1] # 交换优先级 current_priority = current_task.get("priority", 0) prev_priority = prev_task.get("priority", 0) current_task["priority"] = prev_priority prev_task["priority"] = current_priority # 如果当前任务正在下载,且前面的任务处于pending状态,将当前任务暂停 if current_task.get("status") == "downloading" and prev_task.get("status") == "pending": current_task["status"] = "paused" current_task["message"] = "任务已暂停" # 从active_downloads中移除 if task_id in global_state.active_downloads: del global_state.active_downloads[task_id] # 终止正在运行的下载进程 if task_id in global_state.download_processes: process = global_state.download_processes[task_id] process.terminate() try: process.join(timeout=3) if process.is_alive(): process.kill() except Exception: pass finally: del global_state.download_processes[task_id] global_state.save_state() return f"任务 {task_id} 已上移" # 任务下移 def move_task_down(task_id): with global_state.operation_lock: if task_id not in global_state.state["download_tasks"]: return "未找到该任务" # 获取所有任务并按优先级排序(降序) tasks = sorted(global_state.state["download_tasks"].values(), key=lambda x: (-x.get("priority", 0), x.get("start_time", 0))) # 找到当前任务的索引 current_index = -1 for i, task in enumerate(tasks): if task["task_id"] == task_id: current_index = i break if current_index >= len(tasks) - 1: return "任务已经在最底部" # 与后一个任务交换优先级 current_task = tasks[current_index] next_task = tasks[current_index + 1] # 交换优先级 current_priority = current_task.get("priority", 0) next_priority = next_task.get("priority", 0) current_task["priority"] = next_priority next_task["priority"] = current_priority global_state.save_state() return f"任务 {task_id} 已下移" delete_task_btn.click( fn=delete_download_task, inputs=[selected_task], outputs=[task_result] ).then(fn=refresh_all, outputs=[downloads_table, combined_tasks_table]) # 任务置顶、上移、下移按钮 move_top_btn.click( fn=move_task_top, inputs=[selected_task], outputs=[task_result] ).then(fn=refresh_all, outputs=[downloads_table, combined_tasks_table]) move_up_btn.click( fn=move_task_up, inputs=[selected_task], outputs=[task_result] ).then(fn=refresh_all, outputs=[downloads_table, combined_tasks_table]) move_down_btn.click( fn=move_task_down, inputs=[selected_task], outputs=[task_result] ).then(fn=refresh_all, outputs=[downloads_table, combined_tasks_table]) # 刷新 refresh_btn.click( fn=refresh_all, outputs=[downloads_table, combined_tasks_table] ) # 扫描模型 def scan_models(): local_dir = global_state.config.get("local", {}).get("default_model_path", "") if not local_dir: return "请先设置本地目录", [], [] models = scan_local_models() return f"已扫描到 {len(models)} 个模型", get_downloads_data(), get_combined_tasks_data() scan_btn.click( fn=scan_models, outputs=[scan_result, downloads_table, combined_tasks_table] ) # 删除失败任务 def delete_failed_tasks(): with global_state.operation_lock: # 删除失败的下载任务 dl_to_delete = [] for task_id, task in global_state.state["download_tasks"].items(): if task.get("status") == "failed": dl_to_delete.append(task_id) for task_id in dl_to_delete: del global_state.state["download_tasks"][task_id] # 删除失败的上传任务 ul_to_delete = [] for task_id, task in global_state.state["upload_tasks"].items(): if task.get("status") == "failed": ul_to_delete.append(task_id) for task_id in ul_to_delete: del global_state.state["upload_tasks"][task_id] global_state.save_state() return f"已删除 {len(dl_to_delete)} 个失败的下载任务和 {len(ul_to_delete)} 个失败的上传任务" delete_failed_tasks_btn.click( fn=delete_failed_tasks, outputs=[delete_tasks_result] ).then(fn=refresh_all, outputs=[downloads_table, combined_tasks_table]) # 删除已完成任务 def delete_completed_tasks(): with global_state.operation_lock: # 删除已完成的下载任务 dl_to_delete = [] for task_id, task in global_state.state["download_tasks"].items(): if task.get("status") == "completed": dl_to_delete.append(task_id) for task_id in dl_to_delete: del global_state.state["download_tasks"][task_id] # 删除已完成的上传任务 ul_to_delete = [] for task_id, task in global_state.state["upload_tasks"].items(): if task.get("status") == "completed": ul_to_delete.append(task_id) for task_id in ul_to_delete: del global_state.state["upload_tasks"][task_id] global_state.save_state() return f"已删除 {len(dl_to_delete)} 个已完成的下载任务和 {len(ul_to_delete)} 个已完成的上传任务" delete_completed_tasks_btn.click( fn=delete_completed_tasks, outputs=[delete_tasks_result] ).then(fn=refresh_all, outputs=[downloads_table, combined_tasks_table]) # 切换标签页的函数 def switch_to_config_tab(): # 当切换到配置标签时,刷新本地目录输入框的值 local_dir = global_state.config.get("local", {}).get("default_model_path", "") return ( "config_tab", gr.Column(visible=True), # config_tab_content gr.Column(visible=False), # download_tab_content gr.Column(visible=False), # manage_tab_content gr.Button(variant="primary"), # config_btn gr.Button(variant="secondary"), # download_btn gr.Button(variant="secondary"), # manage_btn local_dir # 更新本地目录输入框 ) def switch_to_download_tab(): # 保持与config_tab相同的返回值数量 return ( "download_tab", gr.Column(visible=False), # config_tab_content gr.Column(visible=True), # download_tab_content gr.Column(visible=False), # manage_tab_content gr.Button(variant="secondary"), # config_btn gr.Button(variant="primary"), # download_btn gr.Button(variant="secondary"), # manage_btn global_state.config.get("local", {}).get("default_model_path", "") # 保持数量一致 ) def switch_to_manage_tab(): # 保持与config_tab相同的返回值数量 return ( "manage_tab", gr.Column(visible=False), # config_tab_content gr.Column(visible=False), # download_tab_content gr.Column(visible=True), # manage_tab_content gr.Button(variant="secondary"), # config_btn gr.Button(variant="secondary"), # download_btn gr.Button(variant="primary"), # manage_btn global_state.config.get("local", {}).get("default_model_path", "") # 保持数量一致 ) # 左侧按钮点击事件 config_btn.click( fn=switch_to_config_tab, outputs=[current_tab, config_tab_content, download_tab_content, manage_tab_content, config_btn, download_btn, manage_btn, local_dir_input] ) download_btn.click( fn=switch_to_download_tab, outputs=[current_tab, config_tab_content, download_tab_content, manage_tab_content, config_btn, download_btn, manage_btn, local_dir_input] ) manage_btn.click( fn=switch_to_manage_tab, outputs=[current_tab, config_tab_content, download_tab_content, manage_tab_content, config_btn, download_btn, manage_btn, local_dir_input] ) # 移除了主题切换功能,因为在Gradio 6.0中不支持将Blocks作为输出组件 # 初始化完成 # 应用加载时刷新数据和配置 def load_all(): # 刷新任务数据 downloads_data, combined_data = refresh_all() # 获取最新的本地目录配置 local_dir = global_state.config.get("local", {}).get("default_model_path", "") return downloads_data, combined_data, local_dir app.load(fn=load_all, outputs=[downloads_table, combined_tasks_table, local_dir_input]) return app def main(): print("=" * 50) print("模型下载管理器") print("=" * 50) print(f"本地目录: {global_state.config.get('local', {}).get('default_model_path', '未设置')}") print(f"CSGHUB地址: {global_state.config.get('csghub', {}).get('base_url', '未设置')}") print(f"modelscope: {'已安装' if HAS_MODELSCOPE else '未安装'}") print(f"pycsghub: {'已安装' if HAS_PYCSGHUB else '未安装'}") start_workers() app = create_interface() app.launch( server_name="0.0.0.0", server_port=7865, share=False ) if __name__ == "__main__": main()