#!/usr/bin/env python3 # -*- coding: utf-8 -*- import os import sys import json import time import uuid import shutil import threading import subprocess from datetime import datetime from flask import Flask, request, jsonify, send_from_directory from flask_cors import CORS from flask_socketio import SocketIO, emit # 确保可以导入自定义模块 sys.path.append(os.path.dirname(os.path.abspath(__file__))) # 导入模型管理模块 from model_manager import ModelManager app = Flask(__name__, static_folder='static', template_folder='templates') app.config['SECRET_KEY'] = 'secret!' CORS(app) socketio = SocketIO(app, cors_allowed_origins="*") # 初始化模型管理器 model_manager = ModelManager() # 任务状态 tasks = {} # 线程存储 task_threads = {} # 模型状态数据库 models_db = {} # 数据库锁,用于避免多线程访问冲突 db_lock = threading.Lock() # 数据库文件路径 db_file = 'models_db.json' def save_models_db(): """保存模型数据库到文件""" global models_db try: print(f"[DEBUG] 开始保存数据库,当前线程: {threading.current_thread().name}") # 不使用锁,避免可能的死锁 # with db_lock: # with open(db_file, 'w') as f: # json.dump(models_db, f, indent=2) # 简化版本,直接写入 with open(db_file, 'w') as f: json.dump(models_db, f, indent=2) print(f"[DEBUG] 模型数据库已保存到 {db_file}") except Exception as e: print(f"[ERROR] 保存模型数据库失败: {str(e)}") @app.route('/') def index(): """提供前端页面""" return send_from_directory('../frontend', 'index.html') @app.route('/api/download', methods=['POST']) def download_model(): """下载模型API""" data = request.json model_id = data.get('model_id') local_path = data.get('local_path', '/home/user/models') task_id = data.get('task_id') if not model_id: return jsonify({'error': '请提供有效的模型ID'}), 400 if not task_id: task_id = f"task_{uuid.uuid4().hex}" # 创建任务 tasks[task_id] = { 'task_id': task_id, 'model_id': model_id, 'local_path': local_path, 'type': 'download', 'status': 'pending', 'progress': 0, 'retry_count': 0, 'start_time': datetime.now().isoformat(), 'message': '准备下载...' } # 将模型添加到数据库或更新状态 model_key = f"{model_id}_{local_path}" # 不使用锁,避免可能的死锁 # with db_lock: if model_key not in models_db: models_db[model_key] = { 'model_id': model_id, 'local_path': os.path.join(local_path, model_id), 'status': 'downloading', 'download_time': None, 'upload_time': None, 'upload_repo_id': None } else: # 更新现有模型的状态为下载中 models_db[model_key]['status'] = 'downloading' models_db[model_key]['download_time'] = None # 清除之前的进度信息 models_db[model_key].pop('progress', None) models_db[model_key].pop('message', None) # 保存数据库 save_models_db() print(f"[DEBUG] 模型状态已更新为下载中: {model_id}") # 启动下载线程 thread = threading.Thread(target=download_models, args=([task_id],)) task_threads[task_id] = thread thread.start() return jsonify({'status': 'success', 'task_id': task_id}), 200 @app.route('/api/upload', methods=['POST']) def upload_model(): """上传模型API""" try: data = request.json model_ids = data.get('model_ids', []) create_repo_flag = data.get('create_repo_flag', True) print(f"[DEBUG] 上传API被调用,模型IDs: {model_ids}, 创建仓库: {create_repo_flag}") if not model_ids: return jsonify({'error': '请选择要上传的模型'}), 400 # 为每个模型创建任务 task_ids = [] for model_id in model_ids: # 查找模型信息 model_info = None for key, model in models_db.items(): if model['model_id'] == model_id and model['status'] in ['downloaded', 'uploading']: model_info = model break if not model_info: print(f"[DEBUG] 模型 {model_id} 未找到或状态不允许上传") continue task_id = f"task_{uuid.uuid4().hex}" tasks[task_id] = { 'task_id': task_id, 'model_id': model_id, 'local_path': model_info['local_path'], 'type': 'upload', 'status': 'pending', 'progress': 0, 'start_time': datetime.now().isoformat(), 'message': '准备上传...' } task_ids.append(task_id) # 更新模型状态 model_info['status'] = 'uploading' # 启动上传线程 thread = threading.Thread(target=upload_models, args=(task_ids, create_repo_flag)) # 为每个任务ID存储同一个线程 for task_id in task_ids: task_threads[task_id] = thread thread.start() return jsonify({'task_ids': task_ids}), 200 except Exception as e: print(f"[DEBUG] 上传API错误: {str(e)}") return jsonify({'error': str(e)}), 500 @app.route('/api/delete', methods=['POST']) def delete_model(): """删除模型API""" print("[DEBUG] 删除API被调用") try: data = request.json print(f"[DEBUG] 删除API请求数据: {data}") model_ids = data.get('model_ids', []) if not model_ids: print("[DEBUG] 未提供模型ID") return jsonify({'error': '请选择要删除的模型'}), 400 deleted_models = [] errors = [] for model_id in model_ids: # 查找模型信息 model_key = None model_info = None for key, model in models_db.items(): if model['model_id'] == model_id: model_key = key model_info = model break if not model_info: errors.append(f"模型 {model_id} 不存在") continue # 检查模型状态 if model_info['status'] in ['downloading', 'uploading']: errors.append(f"模型 {model_id} 正在进行下载/上传操作,无法删除") continue try: # 使用模型管理器删除模型 print(f"[DEBUG] 调用模型管理器删除模型: {model_id}, 路径: {model_info['local_path']}") success = model_manager.delete_model(model_info['local_path']) if success: # 从数据库中删除 with db_lock: if model_key: del models_db[model_key] deleted_models.append(model_id) print(f"[DEBUG] 模型 {model_id} 删除成功") else: errors.append(f"删除模型 {model_id} 失败: 模型管理器返回失败") except Exception as e: errors.append(f"删除模型 {model_id} 失败: {str(e)}") print(f"[ERROR] 删除模型 {model_id} 异常: {str(e)}") result = { 'deleted': deleted_models, 'errors': errors } # 如果有模型被删除,保存数据库 if deleted_models: save_models_db() return jsonify(result), 200 except Exception as e: print(f"[ERROR] 删除API异常: {str(e)}") return jsonify({'error': str(e)}), 500 @app.route('/api/models', methods=['GET']) def get_models(): """获取模型列表""" try: # 获取查询参数 status = request.args.get('status') all_models = request.args.get('all', 'false').lower() == 'true' model_path = request.args.get('path') # 获取用户指定的模型路径 print(f"API get_models called with: status={status}, all={all_models}, path={model_path}") # 从本地文件系统获取模型列表,使用用户指定的路径或默认路径 local_models = model_manager.list_models(local_path=model_path) # 合并本地模型和数据库中的模型信息 for local_model in local_models: # 查找数据库中是否有该模型的信息 model_key = f"{local_model['id']}_{os.path.dirname(local_model['path'])}" if model_key in models_db: # 更新本地模型的状态信息 db_model = models_db[model_key] # 保留数据库中的状态,不覆盖进行中的任务状态 local_model['status'] = db_model.get('status', 'downloaded') local_model['download_time'] = db_model.get('download_time') local_model['upload_time'] = db_model.get('upload_time') local_model['upload_repo_id'] = db_model.get('upload_repo_id') # 添加进度信息用于前端显示 if db_model.get('status') in ['downloading', 'uploading']: local_model['progress'] = db_model.get('progress', 0) local_model['message'] = db_model.get('message', '进行中...') print(f"[DEBUG] 从数据库加载模型: {local_model['id']}, 状态: {local_model['status']}") else: # 如果模型不在数据库中,添加到数据库 models_db[model_key] = { 'model_id': local_model['id'], 'local_path': local_model['path'], 'status': 'downloaded', 'download_time': datetime.now().isoformat(), 'upload_time': None, 'upload_repo_id': None } print(f"[DEBUG] 新增模型到数据库: {local_model['id']}") # 同时更新本地模型的状态信息 local_model['status'] = 'downloaded' local_model['download_time'] = models_db[model_key]['download_time'] local_model['upload_time'] = None local_model['upload_repo_id'] = None # 及时保存数据库到文件 save_models_db() # 根据状态筛选 if status == 'downloaded': # 返回已下载但未上传的模型 models = [model for model in local_models if model.get('status') == 'downloaded'] else: # 返回所有模型 models = local_models # 格式化返回数据 result = { 'models': models } return jsonify(result), 200 except Exception as e: print(f"Error getting models: {e}") return jsonify({'error': str(e)}), 500 @app.route('/api/task/', methods=['GET']) def get_task_status(task_id): """获取任务状态""" if task_id not in tasks: return jsonify({'error': '任务不存在'}), 404 return jsonify(tasks[task_id]), 200 @app.route('/api/system/info', methods=['GET']) def get_system_info(): """获取系统信息""" try: # 获取操作系统信息 os_info = subprocess.check_output(['uname', '-a']).decode('utf-8').strip() # 获取磁盘使用情况 disk_usage = subprocess.check_output(['df', '-h']).decode('utf-8') # 获取内存使用情况 memory_usage = subprocess.check_output(['free', '-h']).decode('utf-8') return jsonify({ 'os': os_info, 'disk_usage': disk_usage, 'memory_usage': memory_usage }), 200 except Exception as e: return jsonify({'error': str(e)}), 500 @socketio.on('connect') def handle_connect(): """处理WebSocket连接""" print('客户端已连接') @socketio.on('disconnect') def handle_disconnect(): """处理WebSocket断开连接""" pass # 静默处理断开连接,减少日志输出 @app.route('/api/download/cancel/', methods=['POST']) def cancel_download(task_id): """取消下载任务""" try: print(f"[DEBUG] 收到取消下载请求: {task_id}") # 检查任务是否存在 if task_id in tasks: # 从任务字典中移除任务 task = tasks.pop(task_id) print(f"[DEBUG] 已取消下载任务: {task_id}, 模型ID: {task.get('model_id')}") # 更新模型状态 model_id = task.get('model_id') local_path = task.get('local_path') if model_id and local_path: model_key = f"{model_id}_{local_path}" if model_key in models_db: # 将状态从downloading改为failed if models_db[model_key].get('status') == 'downloading': models_db[model_key]['status'] = 'failed' models_db[model_key].pop('progress', None) models_db[model_key].pop('message', None) save_models_db() print(f"[DEBUG] 已更新模型状态为失败: {model_id}") # 从线程存储中移除 if task_id in task_threads: del task_threads[task_id] return jsonify({'success': True, 'message': '下载任务已取消'}) else: print(f"[DEBUG] 任务不存在: {task_id}") return jsonify({'success': False, 'message': '任务不存在'}), 404 except Exception as e: print(f"[DEBUG] 取消下载失败: {str(e)}") return jsonify({'success': False, 'message': f'取消失败: {str(e)}'}), 500 @app.route('/api/upload/cancel/', methods=['POST']) def cancel_upload(task_id): """取消上传任务""" try: print(f"[DEBUG] 收到取消上传请求: {task_id}") # 检查任务是否存在 if task_id in tasks: # 从任务字典中移除任务 task = tasks.pop(task_id) print(f"[DEBUG] 已取消上传任务: {task_id}, 模型ID: {task.get('model_id')}") # 更新模型状态 model_id = task.get('model_id') local_path = task.get('local_path') if model_id and local_path: model_key = f"{model_id}_{local_path}" if model_key in models_db: # 将状态从uploading改为downloaded if models_db[model_key].get('status') == 'uploading': models_db[model_key]['status'] = 'downloaded' models_db[model_key].pop('progress', None) models_db[model_key].pop('message', None) save_models_db() print(f"[DEBUG] 已更新模型状态为已下载: {model_id}") return jsonify({'success': True, 'message': '上传任务已取消'}) else: print(f"[DEBUG] 任务不存在: {task_id}") return jsonify({'success': False, 'message': '任务不存在'}), 404 except Exception as e: print(f"[DEBUG] 取消上传失败: {str(e)}") return jsonify({'success': False, 'message': f'取消失败: {str(e)}'}), 500 @socketio.on('subscribe_task') def handle_subscribe_task(data): """订阅任务进度更新""" task_id = data.get('task_id') if task_id in tasks: # 立即发送当前状态 emit('task_update', tasks[task_id]) def download_models(task_ids): """下载模型线程""" import concurrent.futures for task_id in task_ids: task = tasks.get(task_id) if not task: continue model_id = task['model_id'] local_path = task['local_path'] # 更新任务状态 task['status'] = 'downloading' task['message'] = f'开始下载模型 {model_id}' socketio.emit('task_update', task) # 尝试下载模型 max_retries = 10 retry_count = 0 # 使用默认参数捕获task_id,避免lambda闭包问题 def make_progress_callback(tid): def progress_callback(progress, detail): update_download_progress(tid, progress, detail) return progress_callback while retry_count < max_retries: # 检查任务是否已被取消 if task_id not in tasks: print(f"[INFO] 下载任务已取消: {task_id}") return try: # 创建取消检查函数 def cancel_check(): return task_id not in tasks # 使用线程池执行下载,以便可以取消 with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: # 提交下载任务 future = executor.submit( model_manager.download_model, model_id=model_id, local_path=local_path, progress_callback=make_progress_callback(task_id), cancel_check=cancel_check ) # 等待下载完成,同时检查任务是否被取消 while not future.done(): # 检查任务是否已被取消 if task_id not in tasks: print(f"[INFO] 下载任务已取消: {task_id}") # 取消future future.cancel() return # 短暂睡眠,避免CPU占用过高 time.sleep(0.1) # 获取下载结果 model_path = future.result() # 下载成功 task['status'] = 'completed' task['progress'] = 100 task['message'] = f'模型 {model_id} 下载完成' # 发送下载完成事件 socketio.emit('download_complete', { 'taskId': task_id, 'modelId': model_id }) socketio.emit('task_update', task) # 更新模型状态 model_key = f"{model_id}_{local_path}" if model_key in models_db: # 只有当状态不是已上传时才更新为已下载 if models_db[model_key].get('status') != 'uploaded': models_db[model_key]['status'] = 'downloaded' models_db[model_key]['download_time'] = datetime.now().isoformat() # 清除进度信息 models_db[model_key].pop('progress', None) models_db[model_key].pop('message', None) break except Exception as e: # 检查任务是否已被取消 if task_id not in tasks: print(f"[INFO] 下载任务已取消: {task_id}") return retry_count += 1 task['retry_count'] = retry_count task['message'] = f'下载失败 (尝试 {retry_count}/{max_retries}): {str(e)}' socketio.emit('task_update', task) if retry_count < max_retries: # 等待一段时间后重试 time.sleep(5) else: # 达到最大重试次数 task['status'] = 'failed' task['message'] = f'达到最大重试次数,下载失败: {str(e)}' # 发送下载失败事件 socketio.emit('download_failed', { 'taskId': task_id, 'modelId': model_id, 'error': str(e) }) socketio.emit('task_update', task) # 更新模型状态 model_key = f"{model_id}_{local_path}" if model_key in models_db: # 只有当状态不是已上传时才更新为失败 if models_db[model_key].get('status') != 'uploaded': models_db[model_key]['status'] = 'failed' models_db[model_key]['download_time'] = None # 清除进度信息 models_db[model_key].pop('progress', None) models_db[model_key].pop('message', None) def update_download_progress(task_id, progress, detail=None): """更新下载进度""" if task_id not in tasks: return task = tasks[task_id] task['progress'] = progress if detail: if isinstance(detail, dict): # 新的进度回调格式 task['message'] = f"正在下载: {detail.get('current_file', 'unknown')} ({detail.get('file_count', 0)}/{detail.get('total_files', 0)})" # 通过WebSocket发送进度更新 socketio.emit('download_progress', { 'taskId': task_id, 'progress': progress, 'fileCount': detail.get('file_count', 0), 'totalFiles': detail.get('total_files', 0), 'currentFile': detail.get('current_file', 'unknown'), 'fileSize': detail.get('file_size', 0) }) else: # 旧的进度回调格式 task['message'] = detail socketio.emit('download_progress', { 'taskId': task_id, 'progress': progress, 'message': detail }) # 发送通用任务更新 socketio.emit('task_update', task) def upload_models(task_ids, create_repo_flag=True): """上传模型线程""" for task_id in task_ids: task = tasks.get(task_id) if not task: continue model_id = task['model_id'] local_path = task['local_path'] # 更新任务状态 task['status'] = 'uploading' task['message'] = f'开始上传模型 {model_id}' socketio.emit('task_update', task) # 使用默认参数捕获task_id,避免lambda闭包问题 def make_upload_progress_callback(tid): def progress_callback(progress, detail): update_upload_progress(tid, progress, detail) return progress_callback try: # 检查任务是否已被取消 if task_id not in tasks: print(f"[INFO] 上传任务已取消: {task_id}") return # 调用模型管理器上传模型 repo_id = os.path.basename(model_id) model_manager.upload_model( local_path=local_path, repo_id=repo_id, create_repo_flag=create_repo_flag, progress_callback=make_upload_progress_callback(task_id) ) # 上传成功 task['status'] = 'completed' task['progress'] = 100 task['message'] = f'模型 {model_id} 上传完成' socketio.emit('task_update', task) # 更新模型状态 for key, model in models_db.items(): if model['model_id'] == model_id: model['status'] = 'uploaded' model['upload_time'] = datetime.now().isoformat() model['upload_repo_id'] = repo_id break except Exception as e: # 检查任务是否已被取消 if task_id not in tasks: print(f"[INFO] 上传任务已取消: {task_id}") return task['status'] = 'failed' task['message'] = f'上传失败: {str(e)}' socketio.emit('task_update', task) # 更新模型状态 for key, model in models_db.items(): if model['model_id'] == model_id: model['status'] = 'downloaded' # 恢复为已下载状态 break def update_upload_progress(task_id, progress, detail=None): """更新上传进度""" if task_id not in tasks: return task = tasks[task_id] task['progress'] = progress if detail: task['message'] = detail socketio.emit('task_update', task) if __name__ == '__main__': # 加载模型数据库 db_file = 'models_db.json' if os.path.exists(db_file): try: with open(db_file, 'r') as f: models_db = json.load(f) except Exception as e: print(f"加载模型数据库失败: {str(e)}") # 清理数据库中状态为 downloading 或 uploading 的任务(可能是之前未正常关闭的任务) for model_key, model_info in models_db.items(): if model_info.get('status') in ['downloading', 'uploading']: print(f"[INFO] 清理异常状态模型: {model_info.get('model_id')}, 状态: {model_info.get('status')}") model_info['status'] = 'failed' model_info.pop('progress', None) model_info.pop('message', None) save_models_db() # 启动服务器 socketio.run(app, host='0.0.0.0', port=2026, debug=False) # 保存模型数据库 try: with open(db_file, 'w') as f: json.dump(models_db, f, indent=2) except Exception as e: print(f"保存模型数据库失败: {str(e)}")