      
"""
ONNX modifier: provides a convenient way to modify ONNX models
1. add node
2. remove node
3. modify node
4. query node
"""

from collections import defaultdict, deque
import os
import os.path as osp
from typing import List, Dict, Set, Tuple, Optional, Union
import numpy as np
import onnx
from onnx import AttributeProto, numpy_helper
from onnx.helper import make_attribute, make_node, make_tensor


class Node:
    """Thin wrapper around an ONNX ``NodeProto`` that also remembers its position.

    Args:
        onnx_modifier: Owning ``ONNXModifier`` used to resolve neighbouring
            nodes; may be ``None``, in which case ``prev_nodes`` /
            ``next_nodes`` raise ``RuntimeError``.
        obj: The wrapped ``NodeProto``. All setters write straight into it.
        index: Cached position of ``obj`` in ``graph.node``; kept up to date
            by the owning modifier.
    """

    def __init__(self, onnx_modifier=None, obj=None, index=None):
        self.onnx_modifier = onnx_modifier
        self.obj = obj
        self.index = index

    @property
    def name(self):
        return self.obj.name

    @property
    def op_type(self):
        return self.obj.op_type

    @property
    def inputs(self):
        return self.obj.input

    @property
    def outputs(self):
        return self.obj.output

    def check_modifier(self):
        """Raise ``RuntimeError`` if this node is not attached to a modifier."""
        if self.onnx_modifier is None:
            raise RuntimeError("onnx_modifier is not initialized")

    @property
    def prev_nodes(self):
        """Nodes producing this node's inputs (delegates to the modifier)."""
        self.check_modifier()
        return self.onnx_modifier.get_prev_nodes(self)

    @property
    def next_nodes(self):
        """Nodes consuming this node's outputs (delegates to the modifier)."""
        self.check_modifier()
        return self.onnx_modifier.get_next_nodes(self)

    def set_input(self, index, name):
        """Rename input slot ``index`` to tensor ``name``."""
        assert index < len(self.obj.input), "index out of range"
        self.obj.input[index] = name

    def set_inputs(self, names):
        """Replace all input names; the number of inputs must not change."""
        assert len(names) == len(self.obj.input), "number of inputs does not match"
        assert all(isinstance(name, str) for name in names), "input names must be strings"
        self.obj.input[:] = names

    def set_output(self, index, name):
        """Rename output slot ``index`` to tensor ``name``."""
        assert index < len(self.obj.output), "index out of range"
        self.obj.output[index] = name

    def set_outputs(self, names):
        """Replace all output names; the number of outputs must not change."""
        assert len(names) == len(self.obj.output), "number of outputs does not match"
        assert all(isinstance(name, str) for name in names), "output names must be strings"
        self.obj.output[:] = names

    @property
    def attrs(self):
        """Return the node's attributes as a ``{name: python value}`` dict.

        FLOAT / INT / STRING become float / int / str, TENSOR a numpy array,
        FLOATS / INTS / STRINGS lists, and TENSORS a list of numpy arrays.
        Any other attribute type is reported as the placeholder string
        ``"Unsupported type: <type>"``.
        """
        attrs = {}
        for attr in self.obj.attribute:
            kind = attr.type
            if kind == AttributeProto.FLOAT:
                value = attr.f
            elif kind == AttributeProto.INT:
                value = attr.i
            elif kind == AttributeProto.STRING:
                value = attr.s.decode('utf-8')
            elif kind == AttributeProto.TENSOR:
                value = numpy_helper.to_array(attr.t)
            elif kind == AttributeProto.FLOATS:
                value = list(attr.floats)
            elif kind == AttributeProto.INTS:
                value = list(attr.ints)
            elif kind == AttributeProto.STRINGS:
                value = [s.decode('utf-8') for s in attr.strings]
            elif kind == AttributeProto.TENSORS:
                value = [numpy_helper.to_array(t) for t in attr.tensors]
            else:
                value = f"Unsupported type: {attr.type}"
            attrs[attr.name] = value
        return attrs

    def set_attribute(self, name, value, name2attr=None):
        """Set attribute ``name`` to ``value``, adding it when it does not exist.

        ``value`` may be anything ``onnx.helper.make_attribute`` accepts, or a
        numpy array (stored as a TENSOR attribute). An existing attribute is
        overwritten as a whole, so switching its type (e.g. INT -> FLOAT)
        leaves no stale value field behind; its doc_string is preserved.

        Args:
            name: Attribute name.
            value: New attribute value.
            name2attr: Optional ``{name: AttributeProto}`` index of
                ``self.obj.attribute`` (``set_attributes`` passes one to avoid
                rebuilding it per call). Newly added attributes are recorded
                in it.
        """
        if not name2attr:
            name2attr = {attr.name: attr for attr in self.obj.attribute}

        if isinstance(value, np.ndarray):
            # make_attribute understands TensorProto, not raw numpy arrays.
            value = numpy_helper.from_array(value)

        existing = name2attr.get(name)
        if existing is None:
            self.obj.attribute.append(make_attribute(name, value))
            # append() stores a copy; index the proto that lives in the node.
            name2attr[name] = self.obj.attribute[-1]
            return

        if isinstance(value, (list, tuple)) and len(value) == 0:
            # make_attribute cannot infer a type from an empty sequence: keep
            # the attribute's current list type (FLOATS if it was not a list).
            list_types = (AttributeProto.FLOATS, AttributeProto.INTS,
                          AttributeProto.STRINGS, AttributeProto.TENSORS,
                          AttributeProto.GRAPHS)
            new_attr = AttributeProto()
            new_attr.name = name
            new_attr.type = (existing.type if existing.type in list_types
                             else AttributeProto.FLOATS)
        else:
            new_attr = make_attribute(name, value)

        if existing.doc_string:
            new_attr.doc_string = existing.doc_string
        # CopyFrom replaces every field, so the old value field (and type)
        # cannot survive next to the new one.
        existing.CopyFrom(new_attr)

    def set_attributes(self, attr_dict):
        """Apply ``set_attribute`` for every ``{name: value}`` pair in ``attr_dict``."""
        name2attr = {attr.name: attr for attr in self.obj.attribute}
        for name, value in attr_dict.items():
            self.set_attribute(name, value, name2attr)
        

class Connection:
    """A named tensor edge whose endpoints are looked up through a modifier.

    Args:
        conn_name: Tensor (value) name shared by a producer output and the
            consumer inputs.
        onnx_modifier: ``ONNXModifier`` used to resolve the attached nodes.
    """

    def __init__(self, conn_name, onnx_modifier=None):
        self.name = conn_name
        self.onnx_modifier = onnx_modifier

    @property
    def from_node(self):
        """Node that produces this tensor, as reported by the modifier."""
        modifier = self.onnx_modifier
        return modifier.get_from_node(self.name)

    @property
    def to_nodes(self):
        """Nodes that consume this tensor, as reported by the modifier."""
        lookup = self.onnx_modifier.get_to_nodes
        return lookup(self.name)


class ONNXModifier:
    """Load an ONNX model and edit its main graph in place.

    Besides ``self.model`` / ``self.graph`` the modifier keeps lookup tables:

    * ``node_map``: node name -> ``Node`` (with its cached position in
      ``graph.node``); lookups therefore assume node names are unique.
    * ``connection_map``: tensor name -> ``[producer node name or None,
      [consumer node names]]``.
    * ``initializer_map`` / ``sparse_initializer_map`` / ``value_info_map``:
      name -> ``[proto, position]``.

    ``create_node`` / ``pop_node`` / ``create_initializer`` keep ``node_map``
    and ``initializer_map`` in sync, but rewiring edges (``connect_node``,
    ``Node.set_input`` ...) does not touch ``connection_map``: call
    ``update_map()`` after such edits.
    """

    def __init__(self, onnx_path):
        self.onnx_path = onnx_path

        self.node_map = {}
        self.initializer_map = {}
        self.sparse_initializer_map = {}
        self.connection_map = {}
        self.value_info_map = {}

        self.parse_onnx(self.onnx_path)

    def parse_onnx(self, onnx_path):
        """Load the model at ``onnx_path`` and rebuild every lookup table."""
        model = onnx.load(onnx_path)

        self.model = model
        self.domain = model.domain
        self.graph = model.graph
        self.ir_version = model.ir_version
        self.model_version = model.model_version
        # Misspelled alias of ``model_version`` kept for existing callers.
        self.mdoel_version = self.model_version
        self.opset_import = model.opset_import

        self.update_map()

        self.initializer_map.clear()
        for i, init in enumerate(self.graph.initializer):
            self.initializer_map[init.name] = [init, i]

        # SparseTensorProto has no ``name`` field; its name lives on ``values``.
        self.sparse_initializer_map.clear()
        for i, sparse in enumerate(self.graph.sparse_initializer):
            self.sparse_initializer_map[sparse.values.name] = [sparse, i]

        self.value_info_map.clear()
        for i, info in enumerate(self.graph.value_info):
            self.value_info_map[info.name] = [info, i]

    @property
    def input_names(self):
        return [i.name for i in self.graph.input]

    @property
    def output_names(self):
        return [o.name for o in self.graph.output]

    # ------------------------------------------------------------------
    # internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _subgraphs(proto):
        """Yield the graphs stored in ``proto``'s GRAPH / GRAPHS attributes."""
        for attr in proto.attribute:
            if attr.type == AttributeProto.GRAPH:
                yield attr.g
            elif attr.type == AttributeProto.GRAPHS:
                yield from attr.graphs

    @classmethod
    def _outer_scope_inputs(cls, proto):
        """Return names read inside ``proto``'s subgraphs but not defined in them."""
        used, defined = set(), set()
        pending = list(cls._subgraphs(proto))
        while pending:
            graph = pending.pop()
            defined.update(value.name for value in graph.input)
            defined.update(init.name for init in graph.initializer)
            defined.update(sp.values.name for sp in graph.sparse_initializer)
            for sub in graph.node:
                used.update(sub.input)
                defined.update(sub.output)
                pending.extend(cls._subgraphs(sub))
        used.discard("")
        return used - defined

    @classmethod
    def _consumed_names(cls, proto):
        """Return ``proto``'s explicit inputs, then the outer-scope values its subgraphs read."""
        explicit = list(proto.input)
        implicit = cls._outer_scope_inputs(proto) - set(explicit)
        return explicit + sorted(implicit)

    @staticmethod
    def _conn_name(conn):
        """Return the tensor name of a ``str`` or ``Connection``."""
        if isinstance(conn, str):
            return conn
        if isinstance(conn, Connection):
            return conn.name
        raise TypeError(f"Invalid connection type {type(conn)}")

    def _as_node(self, node):
        """Resolve a node name to its ``Node``; ``Node`` objects pass through."""
        if isinstance(node, str):
            return self.node_map[node]
        if isinstance(node, Node):
            return node
        raise TypeError(f"Invalid node type {type(node)}")

    # ------------------------------------------------------------------
    # queries
    # ------------------------------------------------------------------
    def get_node(self, name_or_index: Union[str, int]):
        """Return the ``Node`` with the given name or graph position.

        Returns ``None`` for an unknown name or an unsupported key type;
        raises ``ValueError`` when an integer index is too large.
        """
        if isinstance(name_or_index, str):
            return self.node_map.get(name_or_index)
        if isinstance(name_or_index, int):
            if name_or_index >= len(self.graph.node):
                raise ValueError(f"Node index {name_or_index} out of range")
            return self.node_map.get(self.graph.node[name_or_index].name)
        return None

    def get_initializer(self, name_or_index: Union[str, int]):
        """Return the initializer ``TensorProto`` with the given name or position.

        Raises ``ValueError`` for an unknown name or a too-large index.
        """
        if isinstance(name_or_index, str):
            if name_or_index in self.initializer_map:
                return self.initializer_map[name_or_index][0]
            raise ValueError(f"Initializer {name_or_index} not found")
        if isinstance(name_or_index, int):
            if name_or_index < len(self.graph.initializer):
                return self.graph.initializer[name_or_index]
            raise ValueError(
                f"Initializer index {name_or_index} out of range")
        return None

    def get_from_node(self, conn: Union[str, Connection]):
        """Return the ``Node`` producing tensor ``conn``, or ``None`` if no node does."""
        name = self.connection_map.get(self._conn_name(conn), [None, []])[0]
        return self.get_node(name) if name else None

    def get_to_nodes(self, conn: Union[str, Connection]):
        """Return the ``Node`` objects consuming tensor ``conn``.

        Consumers no longer present in ``node_map`` (e.g. popped since the
        last ``update_map()``) are left out.
        """
        names = self.connection_map.get(self._conn_name(conn), [None, []])[1]
        nodes = (self.get_node(name) for name in names)
        return [node for node in nodes if node is not None]

    def get_prev_nodes(self, node: Union[str, Node]):
        """Return the nodes producing ``node``'s inputs (one entry per produced input).

        Values read by ``node``'s subgraphs from the enclosing graph count as
        inputs too.
        """
        node = self._as_node(node)
        prev = []
        for conn_name in self._consumed_names(node.obj):
            producer = self.get_from_node(conn_name)
            if producer is not None:
                prev.append(producer)
        return prev

    def get_next_nodes(self, node: Union[str, Node]):
        """Return the nodes consuming any of ``node``'s outputs (may repeat)."""
        node = self._as_node(node)
        nodes = []
        for conn_name in node.outputs:
            nodes.extend(self.get_to_nodes(conn_name))
        return nodes

    # ------------------------------------------------------------------
    # edits
    # ------------------------------------------------------------------
    def create_node(self, op_type, op_name, inputs, outputs, domain="", index=None, **attrs):
        """Create a node and insert it into the graph.

        Args:
            op_type, op_name, inputs, outputs, domain, attrs: forwarded to
                ``onnx.helper.make_node``.
            index: Position in ``graph.node``; ``None`` appends at the end.

        Returns:
            A ``Node`` wrapping the proto stored in the graph. Protobuf's
            ``append`` / ``insert`` store a copy, so the freshly made proto
            would be detached from the model.

        The node is added to ``node_map`` unless its name is already taken
        (so an existing node of that name can still be popped by name);
        ``connection_map`` is not updated.
        """
        # make_node's 5th positional parameter is doc_string, so pass by keyword.
        proto = make_node(op_type, inputs, outputs, name=op_name, domain=domain, **attrs)
        num_nodes = len(self.graph.node)
        if index is None:
            index = num_nodes
        assert 0 <= index <= num_nodes, "index out of range"

        if index == num_nodes:
            self.graph.node.append(proto)
        else:
            self.graph.node.insert(index, proto)
            # Nodes after the insertion point moved one slot down. Only shift
            # the tracked Node that really sat in the previous slot, and walk
            # backwards so an already-shifted index never matches a later check.
            for i in range(len(self.graph.node) - 1, index, -1):
                tracked = self.node_map.get(self.graph.node[i].name)
                if tracked is not None and tracked.index == i - 1:
                    tracked.index = i

        stored = self.graph.node[index]
        new_node = Node(self, stored, index)
        self.node_map.setdefault(stored.name, new_node)
        return new_node

    def create_initializer(self, name, value: np.ndarray):
        """Append ``value`` as an initializer named ``name`` and register it.

        Returns the ``TensorProto`` stored in the graph (``append`` copies).
        """
        self.graph.initializer.append(numpy_helper.from_array(value, name=name))
        index = len(self.graph.initializer) - 1
        init = self.graph.initializer[index]
        self.initializer_map[name] = [init, index]
        return init

    def connect_node(self, node, inputs_map, outputs_map):
        """Wire ``node`` to upstream and downstream nodes by renaming tensors.

        When connecting A -> B the producer's name wins:
        ``B.input[i] = A.output[j]``.

        Args:
            node: ``Node`` (or node name) to connect.
            inputs_map: ``[(node0, out_idx0), (node1, out_idx1), ...]``; entry
                k sets ``node.input[k] = node_k.output[out_idx_k]``.
            outputs_map: ``[(node0, in_idx0), (node1, in_idx1), ...]``; entry
                k sets ``node_k.input[in_idx_k] = node.output[k]``.

        Upstream/downstream nodes may be ``Node`` objects or node names.
        ``connection_map`` is not updated; call ``update_map()`` afterwards.
        """
        node = self._as_node(node)
        for k, (src, out_idx) in enumerate(inputs_map):
            src = self._as_node(src)
            assert out_idx < len(src.outputs), \
                f"output index {out_idx} out of node {src.name} outputs range"
            node.set_input(k, src.outputs[out_idx])

        for out_name, (dst, in_idx) in zip(node.outputs, outputs_map):
            dst = self._as_node(dst)
            assert in_idx < len(dst.inputs), \
                f"input index {in_idx} out of node {dst.name} inputs range"
            dst.set_input(in_idx, out_name)

    def pop_node(self, node: Union[str, Node, int]):
        """Remove a node given by name, ``Node`` or index, and return its proto.

        Returns ``None`` for an unknown name and raises ``ValueError`` for an
        out-of-range index or unsupported argument. The removed node is dropped
        from ``node_map`` and the cached indices of later nodes are shifted;
        ``connection_map`` is left unchanged.
        """
        num_nodes = len(self.graph.node)
        if isinstance(node, str):
            node = self.node_map.get(node)
            if node is None:
                return None
            index = node.index
            assert node.name == self.graph.node[index].name
        elif isinstance(node, int):
            if not -num_nodes <= node < num_nodes:
                raise ValueError(f"Node index {node} out of range")
            # Normalise negative indices so the shifting loop below is correct.
            index = node % num_nodes
        elif isinstance(node, Node):
            index = node.index
            assert node.name == self.graph.node[index].name
        else:
            raise ValueError(f"Invalid node name or index: {node}")

        removed_name = self.graph.node[index].name
        tracked = self.node_map.get(removed_name)
        if tracked is not None and tracked.index == index:
            del self.node_map[removed_name]
        # Following nodes move one slot up. Only shift the tracked Node that
        # really sits in each slot, walking forwards so an already-shifted
        # index never matches a later check.
        for i in range(index + 1, num_nodes):
            tracked = self.node_map.get(self.graph.node[i].name)
            if tracked is not None and tracked.index == i:
                tracked.index = i - 1

        return self.graph.node.pop(index)

    def remove_nodes(self, nodes: List[Union[str, Node]]):
        """Remove several nodes (names or ``Node`` objects) at once.

        Unknown names and unsupported items are skipped; a node listed twice is
        removed once. Nodes are popped from the highest index down, so each pop
        leaves the cached indices of the remaining targets untouched.
        """
        targets = {}
        for item in nodes:
            if isinstance(item, str):
                item = self.node_map.get(item)
            if isinstance(item, Node):
                targets.setdefault(item.index, item)

        for index in sorted(targets, reverse=True):
            self.pop_node(targets[index])

    def update_map(self):
        """Rebuild ``node_map`` and ``connection_map`` from ``self.graph.node``.

        Empty tensor names (ONNX's marker for an omitted optional input or
        output) are not recorded as connections. Values that a node's
        subgraphs read from the enclosing graph are recorded as consumed by
        that node, so their producers are not mistaken for unused nodes.
        """
        self.node_map.clear()
        self.connection_map.clear()
        for i, node in enumerate(self.graph.node):
            self.node_map[node.name] = Node(self, node, i)

            for conn_name in self._consumed_names(node):
                if conn_name:
                    self.connection_map.setdefault(conn_name, [None, []])[1].append(node.name)
            for conn_name in node.output:
                if conn_name:
                    self.connection_map.setdefault(conn_name, [None, []])[0] = node.name

    def find_unuseful_nodes(self):
        """Return the nodes whose results can never reach a graph output.

        A node is useless when it does not produce a graph output and every
        node consuming its outputs is itself useless; nodes without consumers
        are the seeds. Consumers come from ``connection_map``, so call
        ``update_map()`` first if the graph was edited.
        """
        end_names = set()
        for output_name in self.output_names:
            producer = self.get_from_node(output_name)
            # A graph output may come straight from a graph input/initializer.
            if producer is not None:
                end_names.add(producer.name)

        unuseful_names = set()
        for node in self.node_map.values():
            if node.name not in end_names and not self.get_next_nodes(node):
                unuseful_names.add(node.name)

        queue = deque(self.node_map[name] for name in unuseful_names)
        while queue:
            node = queue.popleft()
            for prev in self.get_prev_nodes(node):
                # Never drop graph-output producers; never revisit a node.
                if prev.name in end_names or prev.name in unuseful_names:
                    continue
                consumers = {nxt.name for nxt in self.get_next_nodes(prev)}
                if consumers.issubset(unuseful_names):
                    unuseful_names.add(prev.name)
                    queue.append(prev)

        return [self.node_map[name] for name in unuseful_names]

    def remove_trash(self):
        """Remove nodes whose results never reach a graph output.

        Rebuilds the maps, removes every node reported by
        ``find_unuseful_nodes`` and rebuilds the maps again so they match the
        pruned graph. Unused initializers and dangling connections are not
        removed yet.
        """
        self.update_map()
        self.remove_nodes(self.find_unuseful_nodes())
        self.update_map()

    def save(self, save_path, save_as_external_data=False,
             all_tensors_to_one_file=True, size_threshold=1024,
             convert_attribute=False):
        """Write the model to ``save_path``.

        With ``save_as_external_data`` (and ``all_tensors_to_one_file``) the
        tensors go to ``<model file name>.data`` next to the model. A file of
        that name already sitting next to the model is deleted first, because
        onnx's external-data writer appends to an existing file instead of
        overwriting it.
        """
        location = osp.basename(save_path) + '.data'
        # ``location`` is resolved relative to the model's directory, not the CWD.
        stale_data = osp.join(osp.dirname(osp.abspath(save_path)), location)
        if osp.isfile(stale_data):
            os.remove(stale_data)

        onnx.save(self.model,
                  save_path,
                  save_as_external_data=save_as_external_data,
                  all_tensors_to_one_file=all_tensors_to_one_file,
                  location=location,
                  size_threshold=size_threshold,
                  convert_attribute=convert_attribute)
