# scripts/model_runner.py
import os
import yaml
from typing import Dict, Any, List
from .base_runner import BaseRunner
from .monitor import MemoryMonitor

class ModelRunner(BaseRunner):
    """模型运行器"""
    
    def __init__(self, config: Dict[str, Any], env):
        super().__init__(config, env)
        
    def load_model_config(self, model_name: str) -> Dict[str, Any]:
        """加载模型配置"""
        config_file = f"config/{model_name}.yaml"
        
        if not os.path.exists(config_file):
            raise FileNotFoundError(f"模型配置文件不存在: {config_file}")
        
        with open(config_file, 'r') as f:
            return yaml.safe_load(f)
    
    def run_single_model(self, model_name: str):
        """运行单个模型"""
        
        print(f"\n{'='*60}")
        print(f"开始测试模型: {model_name}")
        print(f"{'='*60}")
        
        # 加载模型配置
        try:
            model_config = self.load_model_config(model_name)
        except FileNotFoundError as e:
            print(f"错误: {e}")
            return
        
        # 设置环境变量
        self.env.setup(model_config.get('model', {}).get('env_vars', {}))
        
        # 创建监控器
        monitor = MemoryMonitor(
            device_id=self.env.device_id,
            log_file=self.base_config.get('monitor', {}).get('log_file', 'memory_simple.log')
        )
        
        # 获取结果目录
        result_dir = self.env.get_result_dir()
        
        # 获取模型配置
        model_info = model_config.get('model', {})
        
        # 检查是单个模型文件还是多个模型文件
        if 'model_files' in model_info:
            # 多个模型文件的情况（如YOLOv3）
            self.run_multiple_model_files(model_info, monitor, result_dir)
        else:
            # 单个模型文件的情况
            self.run_single_model_file(model_info, monitor, result_dir)
        
        print(f"\n✓ {model_name} 测试完成!")
    
    def run_single_model_file(self, model_info: Dict[str, Any], monitor: MemoryMonitor, result_dir: str):
        """运行单个模型文件（多个batch size）"""
        
        model_file = model_info.get('model_file')
        if not model_file or not os.path.exists(model_file):
            print(f"错误: 模型文件不存在: {model_file}")
            return
        
        model_name = os.path.basename(model_file)
        batch_sizes = model_info.get('batch_sizes', [1, 8])
        
        print(f"模型文件: {model_file}")
        print(f"测试batch大小: {batch_sizes}")
        print(f"{'-'*60}")
        
        for batch in batch_sizes:
            print(f"\n正在测试 batch={batch} ...")
            print(f"{'-'*40}")
            
            # 开始监控
            total_memory = monitor.start()
            
            # 构建命令
            cmd = self.build_command(model_file, model_info, batch)
            
            # 生成日志文件名
            log_file = os.path.join(result_dir, f"{model_name}-{batch}batch.log")
            
            # 运行模型
            success, output = self.run_model(cmd, log_file)
            
            # 停止监控
            monitor.stop()
            
            # 获取统计信息
            stats = monitor.get_statistics()
            
            # 输出统计信息
            print(f"\n=== 显存使用统计 ===")
            print(f"最大使用: {stats['max_used']:.2f} MiB")
            print(f"总显存: {total_memory} MiB")
            print(f"峰值使用率: {stats['max_percent']:.2f}%")
            
            # 将统计信息追加到日志文件
            with open(log_file, 'a') as f:
                f.write(f"\n=== 统计摘要 ===\n")
                f.write(f"最大使用: {stats['max_used']:.2f} MiB\n")
                f.write(f"总显存: {total_memory} MiB\n")
                f.write(f"峰值使用率: {stats['max_percent']:.2f}%\n")
            
            if success:
                print(f"✓ batch={batch} 测试完成，日志保存至: {log_file}")
            else:
                print(f"✗ batch={batch} 测试失败!")
            
            print(f"{'-'*40}")
    
    def run_multiple_model_files(self, model_info: Dict[str, Any], monitor: MemoryMonitor, result_dir: str):
        """运行多个模型文件（如YOLOv3不同batch size有不同文件）"""
        
        model_files = model_info.get('model_files', [])
        if not model_files:
            print("错误: 没有找到模型文件配置")
            return
        
        print(f"测试多个模型文件...")
        print(f"{'-'*60}")
        
        for model_file_info in model_files:
            model_file = model_file_info.get('path')
            batch = model_file_info.get('batch', 1)
            
            if not model_file or not os.path.exists(model_file):
                print(f"警告: 模型文件不存在，跳过: {model_file}")
                continue
            
            print(f"\n正在测试 batch={batch} ...")
            print(f"模型文件: {model_file}")
            print(f"{'-'*40}")
            
            # 开始监控
            total_memory = monitor.start()
            
            # 构建命令
            cmd = self.build_command(model_file, model_info, batch)
            
            # 生成日志文件名
            model_name = os.path.basename(model_file)
            log_file = os.path.join(result_dir, f"{model_name}-{batch}batch.log")
            
            # 运行模型
            success, output = self.run_model(cmd, log_file)
            
            # 停止监控
            monitor.stop()
            
            # 获取统计信息
            stats = monitor.get_statistics()
            
            # 输出统计信息
            print(f"\n=== 显存使用统计 ===")
            print(f"最大使用: {stats['max_used']:.2f} MiB")
            print(f"总显存: {total_memory} MiB")
            print(f"峰值使用率: {stats['max_percent']:.2f}%")
            
            # 将统计信息追加到日志文件
            with open(log_file, 'a') as f:
                f.write(f"\n=== 统计摘要 ===\n")
                f.write(f"最大使用: {stats['max_used']:.2f} MiB\n")
                f.write(f"总显存: {total_memory} MiB\n")
                f.write(f"峰值使用率: {stats['max_percent']:.2f}%\n")
            
            if success:
                print(f"✓ batch={batch} 测试完成，日志保存至: {log_file}")
            else:
                print(f"✗ batch={batch} 测试失败!")
            
            print(f"{'-'*40}")
    
    def run_all_models(self):
        """运行所有模型"""
        
        models_to_run = self.base_config.get('models_to_run', [])
        
        if not models_to_run:
            print("错误: 没有配置要运行的模型")
            return
        
        print(f"将运行以下模型: {models_to_run}")
        
        for model_name in models_to_run:
            try:
                self.run_single_model(model_name)
            except Exception as e:
                print(f"运行模型 {model_name} 时出错: {e}")
                continue
        
        print(f"\n{'='*60}")
        print("所有测试完成!")
        print(f"{'='*60}")