import functools
import inspect
import time
import types

import numpy as np
from prettytable import PrettyTable
import torch


class AutoTimer:
    """Collect wall-clock latencies (in milliseconds) of callables.

    Recording is opt-in: ``start``/``end`` are no-ops until ``start_work`` is
    called. Targets registered with ``add_target``/``add_targets`` are patched
    in place with timing wrappers:

    * module-level functions are rebound by name in the *calling* module's
      globals;
    * bound methods are shadowed by an attribute on ``target.__self__``;
    * callable objects (and ``__call__`` methods) get their ``__class__``
      swapped for a timing proxy subclass.

    NOTE: if a wrapped target re-enters itself (recursion), the inner call
    overwrites the pending start time for its key and then clears it, so the
    outer ``end`` raises ``RuntimeError``. If a target raises, ``end`` is not
    reached and nothing is recorded for that call.
    """

    def __init__(self):
        # key -> list of recorded durations in ms. The misspelled attribute
        # name is kept because it is part of the public surface.
        self.key2durtions = {}
        # key -> perf_counter() value of the pending start, or None once ended.
        self.key2start = {}
        # identifier -> original callable registered for timing.
        self._targets = {}
        # Not written anywhere in this class; kept for compatibility.
        self._proxy_classes = {}
        # identifier -> decorated target / its timing wrapper (wrapped path).
        self._original_refs = {}
        self._wrappers = {}
        self.start_record = False

    @staticmethod
    def _synchronize():
        """Wait for queued CUDA work so timings include it; no-op without CUDA."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def start(self, key):
        """Mark the start of the timed section ``key`` (no-op until start_work)."""
        if not self.start_record:
            return
        self._synchronize()
        # perf_counter is monotonic and high-resolution, unlike time.time().
        self.key2start[key] = time.perf_counter()

    def end(self, key):
        """Close the section opened by ``start(key)`` and record its duration.

        Raises:
            RuntimeError: if ``key`` has no pending start.
        """
        if not self.start_record:
            return
        start_time = self.key2start.get(key)
        if start_time is None:
            raise RuntimeError(f"{key} is not started")
        self.key2start[key] = None
        self._synchronize()
        elapsed_ms = (time.perf_counter() - start_time) * 1000
        self.key2durtions.setdefault(key, []).append(elapsed_ms)

    @staticmethod
    def _latency_stats(durations, batchsize=1):
        """Return (count, total, max, min, mean, fps) for durations in ms.

        All values except the count are rounded to 2 decimals. fps is
        ``batchsize * 1000 / mean`` computed from the *unrounded* mean, and is
        ``inf`` when the mean is zero.
        """
        time_arr = np.array(durations)
        mean = time_arr.mean().item()
        # Rounding the mean first could turn a sub-0.005 ms average into 0.0
        # and raise ZeroDivisionError.
        fps = batchsize * 1000 / mean if mean > 0 else float('inf')
        return (
            len(durations),
            round(time_arr.sum().item(), 2),
            round(time_arr.max().item(), 2),
            round(time_arr.min().item(), 2),
            round(mean, 2),
            round(fps, 2),
        )

    def summary(self, batchsize=1):
        """Print and return a PrettyTable of per-key latency statistics.

        Args:
            batchsize: samples processed per call; scales the fps column.

        Returns:
            The PrettyTable, or None when nothing has been recorded.
        """
        if not self.start_record or not self.key2durtions:
            print("Nothing to summary!")
            return None

        pt = PrettyTable()
        pt.title = "Test Latency"
        # Columns: module, call count, total ms, max ms, min ms, mean ms, fps.
        pt.field_names = ['模块', '运行次数', '总耗时(ms)', '最长耗时(ms)',
                          '最短耗时(ms)', '平均耗时(ms)', '平均性能(fps)']
        for key, durations in self.key2durtions.items():
            pt.add_row([key, *self._latency_stats(durations, batchsize)])

        print(pt)
        return pt

    def print_latency(self, mode='avg', groups=None):
        """Print per-key mean or total latency, optionally summed into groups.

        Args:
            mode: 'avg' for the mean duration per call, 'sum' for the total.
            groups: optional mapping of group name -> keys. Each group's values
                are summed and printed as one line; grouped keys are not
                printed individually. A key may belong to only one group.

        Raises:
            ValueError: if ``mode`` is not 'avg' or 'sum'.
            RuntimeError: if a group names an unknown or already-grouped key.
        """
        if not self.start_record or not self.key2durtions:
            print("Nothing to print!")
            return
        if mode not in ('avg', 'sum'):
            raise ValueError('mode must be avg or sum')

        key2res = {}
        for key, durations in self.key2durtions.items():
            total = sum(durations)
            key2res[key] = total / len(durations) if mode == 'avg' else total

        done_keys = set()
        if groups is not None:
            for group_name, keys in groups.items():
                total = 0
                for key in keys:
                    if key not in key2res:
                        raise RuntimeError(f"key {key} not found!")
                    if key in done_keys:
                        raise RuntimeError(f"key {key} is already counted!")
                    total += key2res[key]
                    done_keys.add(key)
                print(f"{group_name} time spent: {round(total)} ms")

        _mode = 'average' if mode == 'avg' else 'sum'
        for key, value in key2res.items():
            if key in done_keys:
                continue
            print(f"{key} {_mode} time spent: {round(value)} ms")

    def clear(self):
        """Drop recorded durations and pending starts (targets stay patched)."""
        self.key2durtions.clear()
        self.key2start.clear()

    def start_work(self):
        """Enable recording for subsequent start/end calls."""
        self.start_record = True

    @staticmethod
    def _caller_globals():
        """Return the globals of the code that called the public API method.

        Frame chain: this helper -> add_target/add_targets -> caller.
        """
        frame = inspect.currentframe()
        try:
            return frame.f_back.f_back.f_globals
        finally:
            # Release the frame promptly to avoid a reference cycle.
            del frame

    def add_target(self, target, key=None):
        """Register ``target`` for timing and patch it in place.

        Args:
            target: a function, bound method or callable object.
            key: name under which durations are recorded; defaults to an
                identifier derived from ``target``. Re-using a registered key
                is a no-op.

        Raises:
            TypeError: if ``target`` is not callable and ``key`` is given.
            ValueError: if ``key`` is None and no identifier can be derived,
                or if a decorated target's original function cannot be found.
        """
        self._add_target(target, key, self._caller_globals())

    def _add_target(self, target, key, caller_globals):
        """Shared implementation; functions are rebound in ``caller_globals``."""
        identifier = self._get_identifier(target) if key is None else key
        if identifier in self._targets:
            return

        if not callable(target):
            raise TypeError("Only support callable object!")

        if self._is_wrapped_target(target):
            self._wrap_wrapped_target(target, identifier, caller_globals)
        elif isinstance(target, types.MethodType):
            if target.__name__ == '__call__':
                self._wrap_callable_object(target.__self__, identifier)
            else:
                setattr(target.__self__, target.__name__,
                        self._make_timed_wrapper(target, identifier))
        elif isinstance(target, types.FunctionType):
            caller_globals[target.__name__] = self._make_timed_wrapper(
                target, identifier)
        elif callable(target) and hasattr(target, '__call__'):
            self._wrap_callable_object(target, identifier)
        else:
            raise ValueError(f"Unsupported target: {type(target)}")

        # Register only after patching succeeded so a failed attempt can be
        # retried instead of being silently skipped.
        self._targets[identifier] = target

    def _make_timed_wrapper(self, fn, identifier):
        """Return a wrapper that times each call of ``fn`` under ``identifier``."""
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            self.start(identifier)
            result = fn(*args, **kwargs)
            self.end(identifier)
            return result

        return wrapper

    def add_targets(self, targets):
        """Register several targets: callables or ``(callable, key)`` pairs.

        Raises:
            ValueError: if a list/tuple item does not have exactly 2 elements.
        """
        caller_globals = self._caller_globals()
        for item in targets:
            if isinstance(item, (list, tuple)):
                if len(item) != 2:
                    raise ValueError(
                        f"Expected a (target, key) pair, got {item!r}")
                target, key = item
            else:
                target, key = item, None
            self._add_target(target, key, caller_globals)

    def show_targets(self):
        """Print the index and identifier of every registered target."""
        print("Count methods or functions:")
        for i, key in enumerate(self._targets):
            print(i, key)

    def _get_identifier(self, target):
        """Derive a default key: ``Class.method``, ``module.func`` or ``Class.__call__``.

        Raises:
            ValueError: if ``target`` is not callable.
        """
        if isinstance(target, types.MethodType):
            return f"{target.__self__.__class__.__name__}.{target.__name__}"
        if isinstance(target, types.FunctionType):
            return f"{target.__module__}.{target.__name__}"
        if callable(target) and hasattr(target, '__call__'):
            return f"{target.__class__.__name__}.__call__"
        raise ValueError(f"{target} is not a function or method!")

    def _is_wrapped_target(self, target):
        """Heuristically decide whether ``target`` was produced by a decorator.

        A target counts as wrapped if it has ``__wrapped__`` (functools.wraps)
        or its ``__qualname__`` does not match where it should be defined:
        for methods, some class in the owner's MRO; for functions, module top
        level; for callable objects, the object's class.
        """
        if hasattr(target, '__wrapped__'):
            return True

        if isinstance(target, types.MethodType):
            bound_to = target.__self__
            # Classmethods are bound to the class itself.
            owner = bound_to if isinstance(bound_to, type) else type(bound_to)
            defining_scope = target.__qualname__.rpartition('.')[0]
            # Check the whole MRO so inherited methods, and methods of objects
            # whose class was swapped for a timing proxy, are not mistaken for
            # wrappers.
            return not any(
                getattr(cls, '__qualname__', None) == defining_scope
                for cls in owner.__mro__)
        if isinstance(target, types.FunctionType):
            return target.__name__ != target.__qualname__
        if callable(target) and hasattr(target, '__call__'):
            return target.__class__.__name__ not in target.__call__.__qualname__
        raise ValueError(f"{target} is not a function or method!")

    def _get_original_function(self, func):
        """Return the function wrapped by ``func``, or None if not found.

        Uses ``__wrapped__`` when present, otherwise the first callable in the
        closure (a heuristic for decorators that skip functools.wraps).
        """
        if hasattr(func, '__wrapped__'):
            return func.__wrapped__

        for cell in getattr(func, '__closure__', None) or ():
            if callable(cell.cell_contents):
                return cell.cell_contents

        return None

    def _wrap_wrapped_target(self, target, identifier, caller_globals=None):
        """Time a decorated target, rebinding it under its original name.

        The timing wrapper calls the *decorated* target, so decorator overhead
        is included in the measurement.

        Raises:
            ValueError: if the original function cannot be located.
        """
        if not isinstance(target, (types.FunctionType, types.MethodType)):
            # Decorated callable objects are timed like any callable object.
            self._wrap_callable_object(target, identifier)
            return

        orig_func = self._get_original_function(target)
        if orig_func is None:
            raise ValueError(
                f"Cannot locate the function wrapped by {target!r}")

        if (isinstance(target, types.MethodType)
                and orig_func.__name__ == '__call__'):
            # obj() looks __call__ up on the type, so an instance attribute
            # would never be used; time it through a proxy class instead.
            self._wrap_callable_object(target.__self__, identifier)
            return

        wrapper = self._make_timed_wrapper(target, identifier)
        self._original_refs[identifier] = target
        self._wrappers[identifier] = wrapper

        if isinstance(target, types.FunctionType):
            if caller_globals is None:
                caller_globals = self._caller_globals()
            caller_globals[orig_func.__name__] = wrapper
        else:
            setattr(target.__self__, orig_func.__name__, wrapper)

    def _wrap_callable_object(self, obj, identifier, decorator=None):
        """Time ``obj(...)`` by swapping ``obj.__class__`` for a timing subclass.

        Args:
            obj: instance whose class is replaced by ``Proxy<ClassName>``.
            identifier: key under which call durations are recorded.
            decorator: optional decorator applied to the proxy's ``__call__``.
        """
        original_class = obj.__class__

        class ProxyClass(original_class):
            def __call__(proxy_self, *args, **kwargs):
                self.start(identifier)
                result = super(
                    ProxyClass, proxy_self).__call__(*args, **kwargs)
                self.end(identifier)
                return result

            if decorator is not None:
                __call__ = decorator(__call__)

        ProxyClass.__name__ = "Proxy" + original_class.__name__
        obj.__class__ = ProxyClass


# Module-level shared timer instance; it records nothing until
# default_timer.start_work() is called.
default_timer = AutoTimer()
