#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Download models from ModelScope, upload model folders to CsgHub, and manage local copies."""

import glob
import json
import multiprocessing
import os
import shutil
import subprocess
import sys
import time
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

# `requests` is not used by this module itself; the import is kept for backward
# compatibility but must not prevent the module from loading when it is missing.
try:
    import requests
except ImportError:
    requests = None

# Optional third-party back ends: modelscope (download) and pycsghub (upload).
try:
    from modelscope.hub.snapshot_download import snapshot_download
    has_modelscope = True
except ImportError:
    print("Warning: modelscope not installed, download functionality will be limited")
    has_modelscope = False

try:
    from pycsghub.upload_large_folder.main import upload_large_folder_internal, create_repo
    from pycsghub.csghub_api import CsgHubApi
    has_pycsghub = True
except ImportError:
    print("Warning: pycsghub not installed, upload functionality will be limited")
    has_pycsghub = False


def _snapshot_download_worker(model_id: str, cache_dir: str, revision: str) -> None:
    """Run ``snapshot_download`` in a child process.

    Defined at module level so it can be pickled when the multiprocessing
    start method is ``spawn``; a function nested inside a method cannot be.
    """
    try:
        snapshot_download(model_id=model_id, cache_dir=cache_dir, revision=revision)
    except Exception as e:
        print(f"下载出错: {e}")
        raise


def _walk_files(path: str) -> List[str]:
    """Return the full paths of all files found by ``os.walk`` under *path*."""
    found = []
    for root, _dirs, names in os.walk(path):
        found.extend(os.path.join(root, name) for name in names)
    return found


class ModelManager:
    """Model manager used to download, upload, list and delete models."""

    def __init__(self):
        """Set default paths and CsgHub settings, and create the download directory."""
        self.default_download_path = os.path.expanduser("~/models")
        # NOTE(review): the endpoint and access token are hard-coded in source.
        # Consider loading them from configuration or the environment instead.
        self.csghub_config = {
            "base_url": "http://10.17.27.227:4997",
            "token": "f5dad38a9426410aa861155cd184f84a",
            "repo_type": "model",
            "revision": "main",
        }

        # Make sure the default download directory exists.
        os.makedirs(self.default_download_path, exist_ok=True)

    @staticmethod
    def _format_size(num_bytes: int) -> str:
        """Format a byte count as a B/KB/MB/GB string (1024-based, one decimal)."""
        if num_bytes < 1024:
            return f"{num_bytes}B"
        if num_bytes < 1024 * 1024:
            return f"{num_bytes / 1024:.1f}KB"
        if num_bytes < 1024 * 1024 * 1024:
            return f"{num_bytes / (1024 * 1024):.1f}MB"
        return f"{num_bytes / (1024 * 1024 * 1024):.1f}GB"

    @staticmethod
    def _raise_if_cancelled(cancel_check: Optional[Callable], model_id: str) -> None:
        """Raise ``Exception("下载已取消")`` when *cancel_check* reports cancellation."""
        if cancel_check and cancel_check():
            print(f"下载任务已取消: {model_id}")
            raise Exception("下载已取消")

    def _download_in_subprocess(self, model_id: str, cache_dir: str,
                                cancel_check: Optional[Callable]) -> None:
        """Run the ModelScope download in a child process that can be killed on cancel."""
        process = multiprocessing.Process(
            target=_snapshot_download_worker,
            args=(model_id, cache_dir, "master"),
        )
        process.daemon = True
        process.start()

        # Poll so that a cancellation request can terminate the download.
        while process.is_alive():
            if cancel_check and cancel_check():
                print(f"下载任务已取消: {model_id}")
                process.terminate()
                process.join(timeout=1)
                if process.is_alive():
                    process.kill()
                    process.join(timeout=1)
                raise Exception("下载已取消")
            time.sleep(0.1)

        process.join()
        if process.exitcode != 0:
            raise Exception("下载进程异常退出")

    def download_model(self, model_id: str, local_path: Optional[str] = None,
                       progress_callback: Optional[Callable] = None,
                       cancel_check: Optional[Callable] = None) -> str:
        """
        Download a model from ModelScope.

        Any existing directory at the target path is deleted first. The
        download itself reports no progress; after it finishes, the files
        found under the target path are enumerated and reported one by one.

        Args:
            model_id: Model ID in the form "organization/model-name".
            local_path: Directory to download into; defaults to
                ``self.default_download_path``.
            progress_callback: Called as ``progress_callback(progress, detail)``.
                ``progress`` is 0-100, or -1 on failure/cancellation, in which
                case ``detail`` is ``{"error": message}``.
            cancel_check: Callable returning True once the download has been
                cancelled.

        Returns:
            str: The path of the downloaded model (``local_path/model_id``).

        Raises:
            ValueError: If ``model_id`` is empty or resolves outside ``local_path``.
            Exception: If the download fails or is cancelled.
        """
        if not model_id:
            raise ValueError("模型ID不能为空")

        if not local_path:
            local_path = self.default_download_path
        os.makedirs(local_path, exist_ok=True)

        model_path = os.path.join(local_path, model_id)

        # The target directory is removed below, so refuse IDs such as ".",
        # "../x" or absolute paths that would point at or outside local_path.
        base_dir = os.path.normcase(os.path.abspath(local_path))
        target_dir = os.path.normcase(os.path.abspath(model_path))
        if target_dir == base_dir or not target_dir.startswith(os.path.join(base_dir, "")):
            raise ValueError(f"模型ID不合法: {model_id}")

        print(f"\n{'='*50}")
        print("开始下载模型")
        print(f"模型ID: {model_id}")
        print(f"本地路径: {model_path}")
        print(f"{'='*50}")

        # Start from a clean directory if the model already exists.
        if os.path.exists(model_path):
            print(f"模型已存在,正在删除: {model_path}")
            shutil.rmtree(model_path)

        if progress_callback:
            progress_callback(0, f"开始下载模型 {model_id}")

        try:
            if not has_modelscope:
                print(f"modelscope未安装,无法下载模型: {model_id}")
                raise Exception("modelscope未安装,无法下载模型")

            print(f"使用modelscope下载模型: {model_id}")
            print(f"下载目标路径: {model_path}")

            self._raise_if_cancelled(cancel_check, model_id)
            self._download_in_subprocess(model_id, local_path, cancel_check)

            print("modelscope下载完成,正在处理文件...")

            if not os.path.exists(model_path):
                # The download succeeded but nothing is at the expected path.
                raise Exception(f"下载完成但未找到模型目录: {model_path}")

            # Collect the downloaded files, honouring cancellation while walking.
            all_files = []
            for root, _dirs, names in os.walk(model_path):
                self._raise_if_cancelled(cancel_check, model_id)
                all_files.extend(os.path.join(root, name) for name in names)

            file_count = len(all_files)
            print(f"发现 {file_count} 个文件")

            # Report progress per file now that the download has finished.
            for index, file_path in enumerate(all_files, start=1):
                self._raise_if_cancelled(cancel_check, model_id)

                progress = int(index / file_count * 100)
                rel_path = os.path.relpath(file_path, model_path)
                file_size = os.path.getsize(file_path)

                print(f"[{progress}%] 已下载: {rel_path} ({self._format_size(file_size)})")

                if progress_callback:
                    progress_callback(progress, {
                        "file_count": index,
                        "total_files": file_count,
                        "current_file": rel_path,
                        "file_size": file_size,
                    })

                time.sleep(0.05)  # short pause between progress updates

            if progress_callback:
                progress_callback(100, {
                    "file_count": file_count,
                    "total_files": file_count,
                    "current_file": "完成",
                    "message": f"模型 {model_id} 下载完成",
                })

            print(f"\n{'='*50}")
            print("模型下载完成!")
            print(f"模型ID: {model_id}")
            print(f"下载路径: {model_path}")
            print(f"文件数量: {file_count}")
            print(f"{'='*50}")

            return model_path

        except Exception as e:
            error_msg = str(e)
            # Single place where failures (including cancellation) are reported.
            if progress_callback:
                progress_callback(-1, {"error": error_msg})

            print(f"\n{'='*50}")
            print("下载失败!")
            print(f"模型ID: {model_id}")
            print(f"错误信息: {error_msg}")
            print(f"{'='*50}")

            raise Exception(f"下载模型 {model_id} 失败: {error_msg}") from e

    def upload_model(self, local_path: str, repo_id: str, create_repo_flag: bool = True,
                     progress_callback: Optional[Callable] = None) -> Dict[str, Any]:
        """
        Upload a local model folder to CsgHub.

        The repository is addressed as ``root/<repo_id>``. The upload call
        reports no per-file progress, so per-file progress messages are
        emitted after the upload has finished.

        Args:
            local_path: Path of the local model folder.
            repo_id: Repository ID.
            create_repo_flag: Whether to create the repository before uploading.
            progress_callback: Called as ``progress_callback(progress, detail)``;
                ``progress`` is -1 on failure.

        Returns:
            Dict[str, Any]: Result with ``success``, ``repo_id``, ``file_count``
            and ``message`` keys.

        Raises:
            ValueError: If ``local_path`` does not exist or ``repo_id`` is empty.
            Exception: If the upload fails.
        """
        if not local_path or not os.path.exists(local_path):
            raise ValueError(f"本地路径 {local_path} 不存在")

        if not repo_id:
            raise ValueError("仓库ID不能为空")

        if progress_callback:
            progress_callback(0, f"开始上传模型到仓库 {repo_id}")

        try:
            # Enumerate files first; the count drives the progress messages.
            all_files = _walk_files(local_path)
            file_count = len(all_files)
            if file_count == 0:
                raise ValueError(f"本地路径 {local_path} 中没有文件")

            if not has_pycsghub:
                print(f"pycsghub未安装,无法上传模型: {repo_id}")
                raise Exception("pycsghub未安装,无法上传模型")

            csg_api = CsgHubApi()
            use_full_repo_id = f"root/{repo_id}"

            if create_repo_flag:
                if progress_callback:
                    progress_callback(5, "正在创建仓库...")

                create_repo(
                    api=csg_api,
                    repo_id=use_full_repo_id,
                    repo_type=self.csghub_config["repo_type"],
                    revision=self.csghub_config["revision"],
                    endpoint=self.csghub_config["base_url"],
                    token=self.csghub_config["token"],
                )

            if progress_callback:
                progress_callback(10, f"准备上传 {file_count} 个文件...")

            upload_large_folder_internal(
                repo_id=use_full_repo_id,
                local_path=local_path,
                repo_type=self.csghub_config["repo_type"],
                revision=self.csghub_config["revision"],
                endpoint=self.csghub_config["base_url"],
                token=self.csghub_config["token"],
                allow_patterns=None,
                ignore_patterns=None,
                num_workers=1,
                print_report=False,
                print_report_every=1,
            )

            # Emit per-file progress (10% - 100%) after the upload completes.
            for index, file_path in enumerate(all_files, start=1):
                progress = int(index / file_count * 90) + 10
                rel_path = os.path.relpath(file_path, local_path)

                if progress_callback:
                    progress_callback(progress, f"已上传 {index}/{file_count}: {rel_path}")

                time.sleep(0.05)  # short pause between progress updates

            if progress_callback:
                progress_callback(100, f"模型上传完成,仓库ID: {repo_id},共上传 {file_count} 个文件")

            return {
                "success": True,
                "repo_id": repo_id,
                "file_count": file_count,
                "message": f"模型上传成功,共上传 {file_count} 个文件",
            }

        except Exception as e:
            if progress_callback:
                progress_callback(-1, f"上传失败: {str(e)}")
            raise Exception(f"上传模型失败: {str(e)}") from e

    def _inspect_model_dir(self, parent_name: str, name: str, path: str) -> Optional[Dict[str, Any]]:
        """Return the model record for *path*, or None if it is not a model directory."""
        print(f" Checking sub-directory: {name}")

        has_readme = os.path.exists(os.path.join(path, "README.md"))

        try:
            has_safetensors_or_bin = any(
                entry.endswith((".safetensors", ".bin")) for entry in os.listdir(path)
            )
        except Exception as e:
            print(f" Error checking files in {path}: {e}")
            return None

        print(f" - README.md: {has_readme}")
        print(f" - has .safetensors or .bin: {has_safetensors_or_bin}")

        # A model directory needs both a README.md and a weights file.
        if not (has_readme and has_safetensors_or_bin):
            print(" - Skipped (missing required files)")
            return None

        # Field names are the ones the front end expects.
        model_info = {
            "id": name,  # the second-level directory name is used as the id
            "path": path,
            "size": self.get_dir_size(path),
            "status": "downloaded",
            "downloadTime": self.get_dir_creation_time(path),
            "uploadTime": None,
            "upload_repo_id": None,
            "file_count": len(_walk_files(path)),
        }
        print(f" + Added as model: {name} (in {parent_name}/{name})")
        return model_info

    def list_models(self, local_path: str = None) -> list:
        """
        List local models.

        Scans ``local_path/<dir>/<sub-dir>`` and treats a sub-directory as a
        model when it contains a README.md and at least one ``.safetensors``
        or ``.bin`` file directly inside it.

        Args:
            local_path: Directory to scan; defaults to
                ``self.default_download_path``.

        Returns:
            list: One dict per model with the keys ``id``, ``path``, ``size``,
            ``status``, ``downloadTime``, ``uploadTime``, ``upload_repo_id``
            and ``file_count``. Empty when the path does not exist.
        """
        if not local_path:
            local_path = self.default_download_path

        if not os.path.exists(local_path):
            print(f"Model path does not exist: {local_path}")
            return []

        models = []
        try:
            print(f"Listing models from: {local_path}")
            items = os.listdir(local_path)
            print(f"Found {len(items)} items in directory")

            for item in items:
                item_path = os.path.join(local_path, item)
                is_dir = os.path.isdir(item_path)
                print(f"Checking item: {item} (type: {'dir' if is_dir else 'file'})")

                if not is_dir:
                    print(f" - Skipped (not a directory): {item}")
                    continue

                try:
                    sub_items = os.listdir(item_path)
                    print(f" Found {len(sub_items)} sub-items in {item}")

                    for sub_item in sub_items:
                        sub_item_path = os.path.join(item_path, sub_item)
                        if not os.path.isdir(sub_item_path):
                            print(f" - Skipped (not a directory): {sub_item}")
                            continue

                        model_info = self._inspect_model_dir(item, sub_item, sub_item_path)
                        if model_info is not None:
                            models.append(model_info)
                except Exception as e:
                    # Best effort: one unreadable directory must not abort the listing.
                    print(f" Error processing sub-directories in {item_path}: {e}")
                    continue

            print(f"Total models found: {len(models)}")
        except Exception as e:
            print(f"Error listing models: {e}")

        return models

    def get_dir_size(self, path: str) -> str:
        """
        Get the total size of the files under a directory.

        Files whose size cannot be read (for example broken symlinks) are
        skipped. If *path* is a regular file, that file's size is returned.

        Args:
            path: Directory (or file) path.

        Returns:
            str: Formatted size string such as "512B", "1.5KB", "20.0MB" or "1.2GB".
        """
        if os.path.isfile(path):
            return self._format_size(os.path.getsize(path))

        total_size = 0
        for file_path in _walk_files(path):
            try:
                total_size += os.path.getsize(file_path)
            except OSError:
                # Broken symlink or file removed while walking; ignore it.
                continue
        return self._format_size(total_size)

    def get_dir_creation_time(self, path: str) -> str:
        """
        Get the creation time of a directory.

        Uses ``st_birthtime`` when the platform's stat result provides it and
        falls back to the modification time otherwise.

        Args:
            path: Directory path.

        Returns:
            str: Local time formatted as "%Y-%m-%d %H:%M:%S", or "未知" if the
            time cannot be determined.
        """
        try:
            stat = os.stat(path)
            # st_birthtime is not available on every platform.
            creation_time = getattr(stat, "st_birthtime", stat.st_mtime)
            return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(creation_time))
        except Exception:
            return "未知"

    def delete_model(self, model_path: str) -> bool:
        """
        Delete a local model directory.

        Args:
            model_path: Full path of the model directory.

        Returns:
            bool: True if the directory was removed, False if the path is
            empty, does not exist, or removal failed.
        """
        if not model_path or not os.path.exists(model_path):
            print(f"[DEBUG] 模型路径不存在: {model_path}")
            return False

        try:
            print(f"[DEBUG] 开始删除模型目录: {model_path}")
            shutil.rmtree(model_path)
            print(f"[DEBUG] 模型目录删除成功: {model_path}")
            return True
        except Exception as e:
            print(f"[DEBUG] 删除模型失败: {str(e)}")
            return False


def _run_demo() -> None:
    """Run the simplified self-test that prints a fixed list of sample models."""

    # A stand-in manager is defined here to avoid circular imports.
    class TestModelManager:
        """Model manager stub used for the self-test."""

        def list_models(self):
            """Return a fixed list of sample models."""
            return [
                {"model_id": "test-model-1", "size": "1.2GB", "created_at": "2024-01-01 10:00:00"},
                {"model_id": "test-model-2", "size": "800MB", "created_at": "2024-01-02 14:30:00"},
            ]

    manager = TestModelManager()

    print("测试列出模型...")
    models = manager.list_models()
    for model in models:
        print(f"模型: {model['model_id']}, 大小: {model['size']}, 创建时间: {model['created_at']}")

    print("\n注意: 这是一个简化的测试模式,完整功能需要通过app.py运行")


# Self-test entry point.
if __name__ == "__main__":
    _run_demo()