import time
import logging
from contextlib import suppress
from typing import Tuple
import execution
import torch

# Module-level logger shared by the profiler in this file.
logger: logging.Logger = logging.getLogger(name=__name__)


class ExecutionTime:
    """Per-node execution profiler for ``execution.execute``.

    ``patch_execution`` swaps ``execution.execute`` for :meth:`profiled_execute`,
    which times each call with ``time.perf_counter`` and samples CUDA memory
    statistics around it, logging one ``[Profile]`` line per call.
    ``unpatch_execution`` restores the original function.
    """

    def __init__(self, enabled: bool = True):
        # When False, patch_execution() is a no-op.
        self.enabled = enabled
        # The original execution.execute, captured by patch_execution().
        self.origin_execute = None
        self._is_patched = False

    def get_memory_stats(self) -> Tuple[int, int, int, int]:
        """Return CUDA memory statistics for the current CUDA device, in bytes.

        Returns:
            ``(allocated, peak_allocated, reserved, peak_reserved)``. All four
            are 0 when CUDA is unavailable. Peak values cover the period since
            the last peak reset (see :meth:`reset_peak_memory`).
        """
        if not torch.cuda.is_available():
            return 0, 0, 0, 0

        device = torch.device('cuda')
        current_mem = torch.cuda.memory_allocated(device)
        peak_mem = torch.cuda.max_memory_allocated(device)
        reserved_mem = torch.cuda.memory_reserved(device)
        max_reserved = torch.cuda.max_memory_reserved(device)

        return current_mem, peak_mem, reserved_mem, max_reserved

    def reset_peak_memory(self) -> None:
        """Reset all CUDA peak-memory counters (allocated and reserved).

        Uses ``torch.cuda.reset_peak_memory_stats`` directly. The previously
        used ``reset_max_memory_allocated`` is deprecated: it forwards to the
        same function and emits a ``FutureWarning`` on every call, which here
        would be once per executed node.
        """
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats(torch.device('cuda'))

    async def profiled_execute(self, server, dynprompt, caches, current_item,
                               extra_data, executed, prompt_id, execution_list,
                               pending_subgraph_results, pending_async_nodes):
        """Run the original ``execute`` for one item and log its cost.

        Forwards every argument unchanged to the captured original function
        and returns its result unchanged. The profile line is written in a
        ``finally`` clause, so it is also emitted when the wrapped call raises.
        Failures inside the profiling bookkeeping are logged and never
        propagate into execution.

        NOTE(review): the CUDA peak counters are process-global. If several
        execute calls can be in flight at once on the event loop, their memory
        figures will interfere. Confirm whether calls can interleave.
        """
        start_time = time.perf_counter()
        self.reset_peak_memory()
        start_mem = self.get_memory_stats()[0]
        try:
            return await self.origin_execute(
                server, dynprompt, caches, current_item, extra_data,
                executed, prompt_id, execution_list,
                pending_subgraph_results, pending_async_nodes
            )
        finally:
            self._log_profile(dynprompt, current_item, start_time, start_mem)

    def _log_profile(self, dynprompt, unique_id, start_time: float,
                     start_mem: int) -> None:
        """Log one ``[Profile]`` line for a finished item; never raises."""
        end_time = time.perf_counter()
        try:
            _, peak_mem, reserved_mem, max_reserved = self.get_memory_stats()
            try:
                class_type = dynprompt.get_node(unique_id)['class_type']
            except Exception:
                # A missing or malformed node must not hide the timing data.
                class_type = '<unknown>'
            logger.info(
                "[Profile] id: %s type: %s, time: %.5fs, memory: %sB, "
                "peak: %sB, reserved: %sB, max_reserved: %sB",
                unique_id, class_type, end_time - start_time,
                peak_mem - start_mem, peak_mem, reserved_mem, max_reserved,
            )
        except Exception:
            # Profiling is diagnostic only; it must never break execution.
            logger.exception("Failed to record profile for node %s", unique_id)

    def patch_execution(self) -> bool:
        """Install :meth:`profiled_execute` in place of ``execution.execute``.

        Returns:
            True if this call installed the wrapper. False when profiling is
            disabled, already patched, or ``execution`` has no ``execute``
            attribute; a warning is logged in the last case.
        """
        if not self.enabled or self._is_patched:
            return False

        try:
            original = execution.execute
        except AttributeError:
            # Best effort: leave the caller running if the hook is missing.
            logger.warning(
                "execution.execute not found; execution profiling not enabled"
            )
            return False

        self.origin_execute = original
        execution.execute = self.profiled_execute
        self._is_patched = True
        logger.info("Execution profiling enabled")
        return True

    def unpatch_execution(self) -> bool:
        """Restore the ``execution.execute`` captured by :meth:`patch_execution`.

        Returns:
            True if the original function was restored; False when not patched.
        """
        if not self._is_patched or self.origin_execute is None:
            return False
        # origin_execute is kept (not cleared) so that wrapper calls still in
        # flight can finish calling the original function.
        execution.execute = self.origin_execute
        self._is_patched = False
        logger.info("Execution profiling disabled")
        return True
