# scripts/model_runner.py
"""Run configured models at several batch sizes while monitoring device memory."""
import os
from typing import Dict, Any, List

import yaml

from .base_runner import BaseRunner
from .monitor import MemoryMonitor


class ModelRunner(BaseRunner):
    """Model runner.

    For each model name, loads ``config/<model_name>.yaml``, prepares the
    environment, then runs each configured model file / batch size through
    ``build_command`` + ``run_model`` (inherited from ``BaseRunner``) while a
    ``MemoryMonitor`` is active. A memory-usage summary is printed and appended
    to each run's log file.
    """

    def __init__(self, config: Dict[str, Any], env):
        # NOTE(review): presumably BaseRunner stores ``config`` as
        # ``self.base_config`` and ``env`` as ``self.env`` -- both are read below.
        super().__init__(config, env)

    def load_model_config(self, model_name: str) -> Dict[str, Any]:
        """Load the YAML configuration for ``model_name``.

        Args:
            model_name: Base name of the YAML file under ``config/``.

        Returns:
            The parsed document, or an empty dict when the file is empty.

        Raises:
            FileNotFoundError: If ``config/<model_name>.yaml`` does not exist.
        """
        config_file = f"config/{model_name}.yaml"
        try:
            # YAML is conventionally UTF-8; don't depend on the process locale.
            with open(config_file, 'r', encoding='utf-8') as f:
                # safe_load returns None for an empty document.
                return yaml.safe_load(f) or {}
        except FileNotFoundError:
            raise FileNotFoundError(f"模型配置文件不存在: {config_file}") from None

    def run_single_model(self, model_name: str):
        """Run every configured model file / batch size for ``model_name``.

        Prints an error and returns early when the model's config file is
        missing. Other exceptions propagate to the caller.
        """
        print(f"\n{'='*60}")
        print(f"开始测试模型: {model_name}")
        print(f"{'='*60}")

        try:
            model_config = self.load_model_config(model_name)
        except FileNotFoundError as e:
            print(f"错误: {e}")
            return

        # A ``model:`` key with no value parses as None; treat it as empty.
        model_info = model_config.get('model') or {}

        # Apply model-specific environment variables.
        self.env.setup(model_info.get('env_vars') or {})

        monitor_config = self.base_config.get('monitor') or {}
        monitor = MemoryMonitor(
            device_id=self.env.device_id,
            log_file=monitor_config.get('log_file', 'memory_simple.log'),
        )

        result_dir = self.env.get_result_dir()

        if 'model_files' in model_info:
            # One model file per batch size (e.g. YOLOv3).
            self.run_multiple_model_files(model_info, monitor, result_dir)
        else:
            # One model file run at several batch sizes.
            self.run_single_model_file(model_info, monitor, result_dir)

        print(f"\n✓ {model_name} 测试完成!")

    def run_single_model_file(self, model_info: Dict[str, Any], monitor: MemoryMonitor, result_dir: str):
        """Run one model file (``model_info['model_file']``) at each batch size.

        Batch sizes come from ``model_info['batch_sizes']``, defaulting to
        ``[1, 8]`` when the key is absent or null. Prints an error and returns
        when the model file is missing.
        """
        model_file = model_info.get('model_file')
        if not model_file or not os.path.exists(model_file):
            print(f"错误: 模型文件不存在: {model_file}")
            return

        batch_sizes = model_info.get('batch_sizes')
        if batch_sizes is None:
            batch_sizes = [1, 8]

        print(f"模型文件: {model_file}")
        print(f"测试batch大小: {batch_sizes}")
        print(f"{'-'*60}")

        for batch in batch_sizes:
            print(f"\n正在测试 batch={batch} ...")
            print(f"{'-'*40}")
            self._run_batch(model_file, model_info, batch, monitor, result_dir)

    def run_multiple_model_files(self, model_info: Dict[str, Any], monitor: MemoryMonitor, result_dir: str):
        """Run each entry of ``model_info['model_files']``.

        Each entry supplies its own ``path`` and ``batch`` (default 1), e.g.
        for YOLOv3 where each batch size has a separate model file. Entries
        whose file is missing are skipped with a warning.
        """
        model_files = model_info.get('model_files', [])
        if not model_files:
            print("错误: 没有找到模型文件配置")
            return

        print("测试多个模型文件...")
        print(f"{'-'*60}")

        for model_file_info in model_files:
            model_file = model_file_info.get('path')
            batch = model_file_info.get('batch', 1)

            if not model_file or not os.path.exists(model_file):
                print(f"警告: 模型文件不存在,跳过: {model_file}")
                continue

            print(f"\n正在测试 batch={batch} ...")
            print(f"模型文件: {model_file}")
            print(f"{'-'*40}")
            self._run_batch(model_file, model_info, batch, monitor, result_dir)

    def _run_batch(self, model_file: str, model_info: Dict[str, Any], batch: Any,
                   monitor: MemoryMonitor, result_dir: str) -> None:
        """Run ``model_file`` at ``batch`` under ``monitor`` and report memory usage.

        The log is written to ``<result_dir>/<model file name>-<batch>batch.log``.
        The monitor is stopped even if building or running the command raises;
        the exception then propagates to the caller.
        """
        # monitor.start() returns the total-memory figure reported below in MiB.
        total_memory = monitor.start()
        try:
            cmd = self.build_command(model_file, model_info, batch)
            model_name = os.path.basename(model_file)
            log_file = os.path.join(result_dir, f"{model_name}-{batch}batch.log")
            success, _output = self.run_model(cmd, log_file)
        finally:
            # Previously an exception here skipped stop(), leaving the monitor
            # active while run_all_models moved on to the next model.
            monitor.stop()

        self._report_memory_stats(monitor.get_statistics(), total_memory, log_file)

        if success:
            print(f"✓ batch={batch} 测试完成,日志保存至: {log_file}")
        else:
            print(f"✗ batch={batch} 测试失败!")
        print(f"{'-'*40}")

    @staticmethod
    def _report_memory_stats(stats: Dict[str, Any], total_memory: Any, log_file: str) -> None:
        """Print the memory summary and append the same figures to ``log_file``.

        ``stats`` must provide ``max_used`` (MiB) and ``max_percent`` (percent).
        """
        summary = [
            f"最大使用: {stats['max_used']:.2f} MiB",
            f"总显存: {total_memory} MiB",
            f"峰值使用率: {stats['max_percent']:.2f}%",
        ]

        print("\n=== 显存使用统计 ===")
        for line in summary:
            print(line)

        # Explicit UTF-8 so the Chinese summary never fails on a non-UTF-8 locale.
        with open(log_file, 'a', encoding='utf-8') as f:
            f.write("\n=== 统计摘要 ===\n")
            f.writelines(f"{line}\n" for line in summary)

    def run_all_models(self):
        """Run every model listed in ``base_config['models_to_run']``.

        Failures in one model are reported and do not stop the remaining ones.
        """
        models_to_run = self.base_config.get('models_to_run', [])

        if not models_to_run:
            print("错误: 没有配置要运行的模型")
            return

        print(f"将运行以下模型: {models_to_run}")

        for model_name in models_to_run:
            try:
                self.run_single_model(model_name)
            except Exception as e:
                # Top-level boundary: report and continue with the next model.
                print(f"运行模型 {model_name} 时出错: {e}")
                continue

        print(f"\n{'='*60}")
        print("所有测试完成!")
        print(f"{'='*60}")