# scripts/monitor.py
import subprocess
import threading
import time
from datetime import datetime
from typing import Any, Optional

class MemoryMonitor:
    """Periodically sample one device's VRAM usage via ``hy-smi`` and log it.

    Each sample is appended to ``log_file`` as one line of the form
    ``YYYY-MM-DD HH:MM:SS <used_memory> <used_percent>``. Both values are taken
    verbatim from ``hy-smi`` output. ``get_statistics`` reads that file back
    and reports the peak values.
    """

    # Upper bound (seconds) on a single hy-smi call. Without it, a hung tool
    # would block the sampling thread (and therefore stop()) indefinitely.
    _QUERY_TIMEOUT = 10

    def __init__(self, device_id: int, log_file: str = "memory_simple.log",
                 interval: float = 1.0):
        """Create a monitor; no I/O happens until start().

        Args:
            device_id: Device index passed to ``hy-smi -d``.
            log_file: Path of the sample log. It is truncated when the
                sampling loop starts.
            interval: Seconds to wait between samples (default 1.0).
        """
        self.device_id = device_id
        self.log_file = log_file
        self.interval = interval
        self.monitoring = False
        self.monitor_thread: Optional[threading.Thread] = None
        self.total_memory: Optional[int] = None
        # Lets stop() wake the sampling loop immediately instead of waiting
        # out the rest of the current interval.
        self._stop_event = threading.Event()

    def _query_hy_smi(self) -> str:
        """Run ``hy-smi`` for this device and return its stdout.

        An argument list is used instead of a shell string, so nothing is
        interpreted by a shell. Exceptions from ``subprocess.run`` propagate
        to the caller, e.g. ``FileNotFoundError`` when hy-smi is not installed
        or ``subprocess.TimeoutExpired``.
        """
        cmd = [
            "hy-smi", "-d", str(self.device_id),
            "--showmeminfo", "vram", "--showuse",
        ]
        result = subprocess.run(
            cmd, capture_output=True, text=True, timeout=self._QUERY_TIMEOUT
        )
        return result.stdout

    @staticmethod
    def _parse_field(output: str, label: str) -> Optional[str]:
        """Return the value token from the first line of *output* containing *label*.

        The value is the first whitespace-separated token after the second
        ':' on the line. This assumes hy-smi lines look like
        ``HCU[0]  : vram Total Memory (MiB): 65520``; TODO confirm against the
        installed hy-smi version. Matching lines without a value are skipped.
        Returns None when no matching line carries a value.
        """
        for line in output.splitlines():
            if label not in line:
                continue
            parts = line.split(':')
            if len(parts) >= 3:
                tokens = parts[2].split()
                if tokens:
                    return tokens[0]
        return None

    def get_total_memory(self) -> Optional[int]:
        """Query the device's total VRAM as reported by hy-smi.

        Returns:
            The integer value of the "vram Total Memory" field, or None if
            hy-smi fails, the field is missing, or its value is not an integer.
            Failures are printed rather than raised.
        """
        try:
            value = self._parse_field(self._query_hy_smi(), "vram Total Memory")
            if value is not None:
                return int(value)
        except Exception as e:
            # Best-effort: the monitor can still run without the total.
            print(f"获取总显存失败: {e}")

        return None

    def monitor_memory(self):
        """Sampling loop: truncate the log, then append one sample per interval.

        Runs until ``self.monitoring`` becomes False. An error in one iteration
        is printed and the loop continues. A sample line is written only when
        both the used-memory and the usage-percent fields were found.
        """
        with open(self.log_file, 'w', encoding='utf-8'):
            pass  # Truncate samples left over from a previous run.

        while self.monitoring:
            try:
                timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                output = self._query_hy_smi()
                used_memory = self._parse_field(output, "vram Total Used Memory")
                used_percent = self._parse_field(output, "HCU use")

                if used_memory and used_percent:
                    with open(self.log_file, 'a', encoding='utf-8') as f:
                        f.write(f"{timestamp} {used_memory} {used_percent}\n")

            except Exception as e:
                print(f"监控出错: {e}")

            # Acts like time.sleep(interval) but returns early once stop() sets the event.
            self._stop_event.wait(self.interval)

    def start(self):
        """Query total VRAM, then start background sampling.

        Returns:
            The total VRAM from get_total_memory(), or None if unavailable.
        """
        self.total_memory = self.get_total_memory()

        if self.total_memory is not None:
            print(f"总显存: {self.total_memory} MiB")

        print("开始监控显存使用...")

        self.monitoring = True
        self._stop_event.clear()
        # daemon=True: if a caller forgets stop(), this endless loop must not
        # keep the interpreter alive at exit.
        self.monitor_thread = threading.Thread(target=self.monitor_memory, daemon=True)
        self.monitor_thread.start()

        return self.total_memory

    def stop(self):
        """Signal the sampling loop to exit and wait up to 2 s for the thread."""
        self.monitoring = False
        self._stop_event.set()
        if self.monitor_thread:
            self.monitor_thread.join(timeout=2)

    def get_statistics(self) -> dict[str, Any]:
        """Summarise the sample log.

        Returns:
            A dict with three keys:

            - ``total_memory``: the value stored by the last start().
            - ``max_used``: the peak used-memory value in the log, as a float.
            - ``max_percent``: the peak usage-percent value, as a float.

            Each peak is 0 when the log holds no valid samples. Lines without
            four whitespace-separated fields, or with non-numeric values, are
            skipped instead of invalidating the whole summary. A missing or
            unreadable log is reported with a print.
        """
        stats: dict[str, Any] = {
            "total_memory": self.total_memory,
            "max_used": 0,
            "max_percent": 0,
        }

        used_values: list[float] = []
        percent_values: list[float] = []
        try:
            with open(self.log_file, 'r', encoding='utf-8') as f:
                for line in f:
                    # Layout written by monitor_memory: date, time, used, percent.
                    fields = line.split()
                    if len(fields) < 4:
                        continue
                    try:
                        used, percent = float(fields[2]), float(fields[3])
                    except ValueError:
                        continue
                    used_values.append(used)
                    percent_values.append(percent)
        except (OSError, ValueError) as e:
            # ValueError here covers UnicodeDecodeError from a corrupted log.
            print(f"读取监控日志失败: {e}")

        if used_values:
            stats["max_used"] = max(used_values)
        if percent_values:
            stats["max_percent"] = max(percent_values)

        return stats