package utils

// Tree is a rooted tree whose nodes are identified by ids of type K, unique
// within the tree, and carry a payload of type V. Alongside the parent/child
// links it keeps an id index so any node can be looked up directly by id.
//
// Use NewTree to create a Tree. A zero Tree is also usable: it starts empty and
// its id index is allocated by the first AddChildNode call.
type Tree[K comparable, V any] struct {
	root     *TreeNode[K, V]       // root node; nil while the tree is empty
	size     int                   // number of nodes currently in the tree
	flatten  map[K]*TreeNode[K, V] // index of every node in the tree, keyed by id
	maxDepth int                   // greatest depth of any node (the root has depth 0)
}

// NewTree returns an empty Tree.
func NewTree[K comparable, V any]() *Tree[K, V] {
	return &Tree[K, V]{flatten: make(map[K]*TreeNode[K, V])}
}

// Size returns the number of nodes currently in the tree.
func (t *Tree[K, V]) Size() int {
	return t.size
}

// AddChildNode creates a node with the given id and data, attaches it to the
// tree as a child of parent, and returns it. A nil parent makes the new node
// the root, which is only allowed while the tree has no root.
//
// AddChildNode returns nil and leaves the tree unchanged when:
//   - parent is nil but the tree already has a root;
//   - a node with the same id is already in the tree (accepting it would
//     overwrite the id index entry and over-count Size);
//   - parent is not a node currently in this tree, e.g. a node belonging to
//     another Tree or one detached by DelNode.
func (t *Tree[K, V]) AddChildNode(id K, data V, parent *TreeNode[K, V]) *TreeNode[K, V] {
	if _, exists := t.flatten[id]; exists {
		return nil
	}
	if t.flatten == nil {
		// Lazily allocate the index so a zero Tree behaves like one from NewTree.
		t.flatten = make(map[K]*TreeNode[K, V])
	}

	if parent == nil {
		if t.root != nil {
			return nil
		}
		n := newTreeNode(id, data, nil)
		t.root = n
		t.flatten[id] = n
		t.size++
		return n
	}

	// Only nodes indexed by this tree may receive children; otherwise the size,
	// index and depth bookkeeping would describe a different structure.
	if t.flatten[parent.id] != parent {
		return nil
	}
	n := newTreeNode(id, data, parent)
	n.depth = parent.depth + 1
	parent.child[id] = n
	t.flatten[id] = n
	t.size++
	t.maxDepth = max(t.maxDepth, n.depth)
	return n
}

// DelNode removes the node with the given id from the tree together with its
// whole subtree and returns it. It returns nil if no such node exists.
// Deleting the root empties the tree.
//
// The returned node is detached: its parent is nil, and depths inside the
// removed subtree are rebased so that the returned node has depth 0. Its
// children are kept, so the detached subtree stays intact, but none of its
// nodes are indexed by, or counted in, this tree any more.
func (t *Tree[K, V]) DelNode(id K) *TreeNode[K, V] {
	n, ok := t.flatten[id]
	if !ok {
		return nil
	}
	if n.parent == nil {
		// n is the root: removing it removes everything.
		t.root = nil
	} else {
		delete(n.parent.child, id)
		n.parent = nil
	}

	// Now that n has no parent, flattenToList returns n itself together with
	// all of its descendants.
	removed := n.flattenToList()
	base, deepest := n.depth, 0
	for _, v := range removed {
		deepest = max(deepest, v.depth)
		delete(t.flatten, v.id)
		v.depth -= base
	}
	t.size -= len(removed)

	// Only rescan when the removed subtree may have held the deepest level.
	if deepest >= t.maxDepth {
		t.maxDepth = 0
		for _, v := range t.flatten {
			t.maxDepth = max(t.maxDepth, v.depth)
		}
	}
	return n
}

// GetNodeById returns the node with the given id, or nil if the tree has none.
func (t *Tree[K, V]) GetNodeById(id K) *TreeNode[K, V] {
	return t.flatten[id]
}

// FlattenToList returns every node of the tree as a slice, or nil when the tree
// is empty. Every non-root node is immediately followed by all of its
// descendants, and the root comes last; sibling order is unspecified.
func (t *Tree[K, V]) FlattenToList() []*TreeNode[K, V] {
	if t.root == nil {
		return nil
	}
	return t.root.flattenToList()
}

// GetRoot returns the root node, or nil if the tree is empty.
func (t *Tree[K, V]) GetRoot() *TreeNode[K, V] {
	return t.root
}

// GetDepth returns the greatest depth of any node in the tree. The root has
// depth 0, so an empty tree or a tree holding only a root reports 0.
func (t *Tree[K, V]) GetDepth() int {
	return t.maxDepth
}

// TreeNode is a node of a Tree. Nodes are normally obtained from
// Tree.AddChildNode.
type TreeNode[K comparable, V any] struct {
	parent *TreeNode[K, V]       // parent node; nil for a root
	child  map[K]*TreeNode[K, V] // direct children, keyed by their id
	Data   V                     // caller-supplied payload
	id     K                     // id, unique within the owning tree
	depth  int                   // depth: 0 for the root, +1 per level below it
}

// GetId returns the node's id.
func (n *TreeNode[K, V]) GetId() K {
	return n.id
}

// GetParent returns the node's parent, or nil for a root (including the root
// of a subtree detached by Tree.DelNode).
func (n *TreeNode[K, V]) GetParent() *TreeNode[K, V] {
	return n.parent
}

// newTreeNode allocates a childless node with the given id, payload, parent and
// depth 0. It does not add the node to parent's children or set its depth;
// callers do that themselves.
func newTreeNode[K comparable, V any](key K, data V, parent *TreeNode[K, V]) *TreeNode[K, V] {
	return &TreeNode[K, V]{
		parent: parent,
		child:  make(map[K]*TreeNode[K, V]),
		Data:   data,
		id:     key,
	}
}

// flattenToList returns all descendants of n, each child immediately followed
// by its own descendants. n itself is appended last, but only when it has no
// parent (a root, or a subtree root detached by Tree.DelNode); a node with a
// parent is not part of its own result. The result is never nil. Sibling order
// follows map iteration and is unspecified.
func (n *TreeNode[K, V]) flattenToList() []*TreeNode[K, V] {
	// Accumulate into one shared slice instead of allocating and copying a
	// separate slice at every level of the recursion.
	result := n.appendDescendants(make([]*TreeNode[K, V], 0, len(n.child)+1))
	if n.parent == nil {
		result = append(result, n)
	}
	return result
}

// appendDescendants appends every descendant of n to dst in pre-order (each
// child before its own descendants) and returns the extended slice.
func (n *TreeNode[K, V]) appendDescendants(dst []*TreeNode[K, V]) []*TreeNode[K, V] {
	for _, c := range n.child {
		dst = append(dst, c)
		dst = c.appendDescendants(dst)
	}
	return dst
}

// childs returns n's direct children as a new, never-nil slice in unspecified
// (map iteration) order.
func (n *TreeNode[K, V]) childs() []*TreeNode[K, V] {
	result := make([]*TreeNode[K, V], 0, len(n.child))
	for _, c := range n.child {
		result = append(result, c)
	}
	return result
}
