# scripts/monitor.py
import subprocess
import threading
import time
from datetime import datetime
from typing import Any, Optional


class MemoryMonitor:
    """Record device memory usage reported by the ``hy-smi`` command-line tool.

    A background thread runs ``hy-smi`` about once per second and appends one
    line per sample to ``log_file`` in the form
    ``<date> <time> <used_memory> <used_percent>``.  ``get_statistics`` reads
    that file back and reports the peak values.

    Attributes:
        device_id: Device index passed to ``hy-smi -d``.
        log_file: Path of the sample log (truncated when monitoring starts).
        monitoring: Flag polled by the worker thread; ``False`` stops it.
        monitor_thread: The worker thread, or ``None`` before ``start()``.
        total_memory: Total memory parsed by ``start()``, or ``None``.
    """

    def __init__(self, device_id: int, log_file: str = "memory_simple.log"):
        self.device_id = device_id
        self.log_file = log_file
        self.monitoring = False
        self.monitor_thread: Optional[threading.Thread] = None
        self.total_memory: Optional[int] = None

    def _query_smi(self) -> str:
        """Run ``hy-smi`` for this device and return its standard output."""
        # An argument list (no shell) keeps device_id from ever being
        # interpreted as shell syntax.
        cmd = [
            "hy-smi",
            "-d",
            str(self.device_id),
            "--showmeminfo",
            "vram",
            "--showuse",
        ]
        result = subprocess.run(cmd, capture_output=True, text=True)
        return result.stdout

    @staticmethod
    def _find_values(output: str, label: str) -> list[str]:
        """Return the value token of every line in *output* containing *label*.

        A value line is expected to have at least three ``:``-separated
        fields; the value is the first whitespace-delimited token of the
        third field.  Lines that do not fit that shape are skipped.
        """
        values: list[str] = []
        for line in output.split("\n"):
            if label not in line:
                continue
            parts = line.split(":")
            if len(parts) < 3:
                continue
            tokens = parts[2].split()
            if tokens:
                values.append(tokens[0])
        return values

    @staticmethod
    def _to_float(token: str) -> Optional[float]:
        """Convert *token* to ``float``, returning ``None`` if it is not numeric."""
        try:
            return float(token)
        except ValueError:
            return None

    def get_total_memory(self) -> Optional[int]:
        """Return the total memory reported by ``hy-smi``.

        Returns:
            The integer parsed from the first "vram Total Memory" line, or
            ``None`` when the tool cannot be run, the line is absent, or the
            value is not an integer (a message is printed on errors).
        """
        try:
            values = self._find_values(self._query_smi(), "vram Total Memory")
            if values:
                return int(values[0])
        except (OSError, subprocess.SubprocessError, ValueError) as e:
            print(f"获取总显存失败: {e}")
        return None

    def monitor_memory(self):
        """Worker loop: sample usage once per second until ``monitoring`` is False.

        Truncates ``log_file`` first, then appends one line per sample for
        which both the used-memory and the usage-percent values were found.
        """
        # Start from an empty log so statistics only cover this run.
        with open(self.log_file, "w", encoding="utf-8"):
            pass

        while self.monitoring:
            # Broad catch on purpose: this is the thread's top-level boundary
            # and a single failed sample must not end the monitoring loop.
            try:
                timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                output = self._query_smi()
                used = self._find_values(output, "vram Total Used Memory")
                percent = self._find_values(output, "HCU use")
                if used and percent:
                    # If several lines match, the last one wins.
                    with open(self.log_file, "a", encoding="utf-8") as f:
                        f.write(f"{timestamp} {used[-1]} {percent[-1]}\n")
            except Exception as e:
                print(f"监控出错: {e}")
            time.sleep(1)

    def start(self):
        """Query total memory and start the background sampling thread.

        Calling ``start()`` while a sampling thread is already running is a
        no-op, so two threads never write to the same log.

        Returns:
            The total memory value (``None`` if it could not be determined).
        """
        already_running = (
            self.monitoring
            and self.monitor_thread is not None
            and self.monitor_thread.is_alive()
        )
        if already_running:
            return self.total_memory

        self.total_memory = self.get_total_memory()
        if self.total_memory:
            print(f"总显存: {self.total_memory} MiB")
        print("开始监控显存使用...")
        self.monitoring = True
        self.monitor_thread = threading.Thread(target=self.monitor_memory)
        self.monitor_thread.start()
        return self.total_memory

    def stop(self):
        """Signal the sampling thread to stop and wait up to 2 seconds for it."""
        self.monitoring = False
        if self.monitor_thread:
            self.monitor_thread.join(timeout=2)

    def get_statistics(self) -> dict[str, Any]:
        """Return peak usage figures computed from the sample log.

        Returns:
            A dict with ``total_memory`` (as stored on the instance),
            ``max_used`` and ``max_percent``.  The two maxima stay ``0`` when
            the log is missing, unreadable, or holds no numeric samples.
        """
        stats: dict[str, Any] = {
            "total_memory": self.total_memory,
            "max_used": 0,
            "max_percent": 0,
        }
        used_values: list[float] = []
        percent_values: list[float] = []
        try:
            with open(self.log_file, "r", encoding="utf-8") as f:
                for line in f:
                    # Fields: date, time, used memory, usage percent (the
                    # timestamp itself contains a space).
                    fields = line.split()
                    # Non-numeric tokens are skipped so that one malformed
                    # or partially written line cannot discard every sample.
                    if len(fields) >= 3:
                        used = self._to_float(fields[2])
                        if used is not None:
                            used_values.append(used)
                    if len(fields) >= 4:
                        percent = self._to_float(fields[3])
                        if percent is not None:
                            percent_values.append(percent)
            if used_values:
                stats["max_used"] = max(used_values)
            if percent_values:
                stats["max_percent"] = max(percent_values)
        except (OSError, ValueError) as e:
            print(f"读取监控日志失败: {e}")
        return stats