package docker

import (
	"get-container/utils"

	"context"
	"errors"
	"fmt"
	"github.com/moby/moby/api/types/container"
	"github.com/moby/moby/client"
	"os"
	"regexp"
	"strings"
	"sync"
	"time"
)

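/**
有两种方法获取进程属于哪个容器
1. 通过查询pid命名空间
2. 通过查询进程的cgroup

There are two ways to find which container a process belongs to:
1. by comparing the process's pid namespace with each container's (ByPidNS)
2. by parsing the process's cgroup (ByCgroup)
*/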

// FindCIDMethod selects the strategy used to work out which Docker container
// a process runs in.
type FindCIDMethod string

// Supported lookup strategies.
const (
	// ByPidNS compares the process's pid namespace with the pid namespace of
	// each running container's main process.
	ByPidNS FindCIDMethod = "byPidNS"
	// ByCgroup extracts the container ID from /proc/<pid>/cgroup.
	ByCgroup FindCIDMethod = "byCGroup"
)

var (
	// ReDocker extracts a Docker container ID from a cgroup path, e.g.
	// "12:pids:/docker/<id>" or "0::/system.slice/docker-<id>.scope".
	// Capture group 1 holds the ID. The greedy leading ".*" prefers the last
	// "docker-" / "docker/" occurrence that is followed by at least one ID
	// character; "+" (instead of "*") guarantees the capture is never empty, so
	// a path that merely ends in "docker/" is not reported as a container.
	ReDocker = regexp.MustCompile(`^.*docker[-/]([0-9a-z]+)`)

	// ContainerInfo is the package-wide cache of running-container details used
	// by the pid-namespace lookups. It stays nil until a refresh succeeds.
	ContainerInfo *ContainersInfo
)

// ContainersInfo caches the data returned by getContainerInfo: for every
// running container, its inspect response (inspectInfo) and its list summary
// (listInfo), both keyed by container ID.
type ContainersInfo struct {
	lock        sync.RWMutex // read-write lock guarding the fields below against concurrent writes
	time        time.Time    // when inspectInfo/listInfo were last written
	inspectInfo map[string]container.InspectResponse
	listInfo    map[string]container.Summary
}

// Update re-queries the Docker daemon via getContainerInfo and replaces the
// cached inspect/list data and its timestamp.
//
// The daemon round-trips run before the write lock is taken. Goroutines
// reading through Get are therefore blocked only for the brief swap, not for
// the network I/O. If several Updates run concurrently, the last one to finish
// wins. On error the previously cached data is left untouched and the error is
// returned.
func (info *ContainersInfo) Update() error {
	inspects, summaries, err := getContainerInfo()
	if err != nil {
		return err
	}

	info.lock.Lock()
	defer info.lock.Unlock()
	info.inspectInfo = inspects
	info.listInfo = summaries
	info.time = time.Now()
	return nil
}

// Get returns the cached inspect responses, keyed by container ID, together
// with a Locker for the read lock it has already acquired on them.
// The caller must call Unlock on the returned Locker once it is done reading
// the map; until then Update cannot swap in new data.
func (info *ContainersInfo) Get() (map[string]container.InspectResponse, sync.Locker) {
	info.lock.RLock()
	return info.inspectInfo, info.lock.RLocker()
}

// init eagerly populates ContainerInfo when the package is loaded, which means
// the Docker daemon is contacted during package initialization.
// The error is deliberately discarded. On failure initContainerInfo leaves
// ContainerInfo unassigned, and findContainerIdByNS / findContainerIdByNSBatch
// call initContainerInfo again when they find it nil.
func init() {
	_ = initContainerInfo()
}

// initContainerInfo fetches the current container data and publishes it as a
// brand-new ContainersInfo in the package-level ContainerInfo variable.
// On error ContainerInfo is not modified and the error is returned.
//
// NOTE(review): ContainerInfo is assigned without any synchronization, so two
// goroutines that both observe a nil ContainerInfo may race here.
func initContainerInfo() error {
	inspectByID, summaryByID, fetchErr := getContainerInfo()
	if fetchErr != nil {
		return fetchErr
	}
	// The zero value of the embedded RWMutex is an unlocked mutex.
	fresh := new(ContainersInfo)
	fresh.time = time.Now()
	fresh.inspectInfo = inspectByID
	fresh.listInfo = summaryByID
	ContainerInfo = fresh
	return nil
}

// FindContainerIdByPid reports which Docker container the process pid belongs
// to, using the lookup strategy selected by method, and returns the container
// ID. A nil ID with a nil error means no container was found. An unrecognized
// method yields an "unknown method" error.
func FindContainerIdByPid(pid uint64, method FindCIDMethod) (*string, error) {
	var lookup func(uint64) (*string, error)
	switch method {
	case ByCgroup:
		lookup = findContainerIdByCgroup
	case ByPidNS:
		lookup = findContainerIdByNS
	default:
		return nil, fmt.Errorf("unknown method: %s", method)
	}
	return lookup(pid)
}

// FindContainerIdByPidBatch resolves the Docker container of every pid in pids
// using the strategy selected by method. It returns a map from pid to
// container ID; how unresolvable pids are treated is decided by the
// per-method helper.
//
// An empty (or nil) pids slice yields a nil map and a nil error without
// consulting the Docker daemon or /proc. An unrecognized method yields an
// "unknown method" error.
func FindContainerIdByPidBatch(pids []uint64, method FindCIDMethod) (map[uint64]string, error) {
	// len of a nil slice is 0, so this single check covers nil and empty input.
	if len(pids) == 0 {
		return nil, nil
	}
	switch method {
	case ByPidNS:
		return findContainerIdByNSBatch(pids)
	case ByCgroup:
		return findContainerIdByCgroupBatch(pids)
	default:
		return nil, fmt.Errorf("unknown method: %s", method)
	}
}

// findContainerIdByCgroup resolves the Docker container ID of pid from
// /proc/<pid>/cgroup.
//
// Each line of that file has the form "hierarchy-ID:controller-list:path".
// When the file has a single line (e.g. a cgroup v2 "0::/path" entry), that
// line is used as is. When it has several lines, the one whose
// controller-list field contains "pids" is used. The container ID is capture
// group 1 of ReDocker applied to the chosen line.
//
// It returns a non-nil ID, or an error in any of these cases:
//   - the file cannot be read or is empty;
//   - the file has several lines and none of them is a "pids" line;
//   - the chosen line does not yield a non-empty Docker container ID.
func findContainerIdByCgroup(pid uint64) (*string, error) {
	content, err := os.ReadFile(fmt.Sprintf("/proc/%d/cgroup", pid))
	if err != nil {
		return nil, err
	}
	contentStr := strings.Trim(string(content), "\n")
	if len(contentStr) == 0 {
		return nil, errors.New("process's cgroup not found")
	}
	lines := strings.Split(contentStr, "\n")

	var target string
	if len(lines) == 1 {
		target = strings.TrimSpace(lines[0])
	} else {
		for _, line := range lines {
			line = strings.TrimSpace(line)
			// Match on the controller-list field only, so a cgroup path that
			// happens to contain "pids" is not mistaken for the pids hierarchy.
			if cgroupLineHasController(line, "pids") {
				target = line
				break
			}
		}
		if target == "" {
			return nil, errors.New("process's cgroup not found pids line")
		}
	}

	fields := ReDocker.FindStringSubmatch(target)
	// No match, or a match with an empty ID, means this is not a Docker cgroup;
	// an empty string is never a usable container ID.
	if len(fields) < 2 || fields[1] == "" {
		return nil, errors.New("process's cgroup is not create by docker")
	}
	cid := fields[1]
	return &cid, nil
}

// cgroupLineHasController reports whether a /proc/<pid>/cgroup line of the
// form "hierarchy-ID:controller-list:path" lists controller in its
// comma-separated controller-list field.
func cgroupLineHasController(line, controller string) bool {
	_, rest, ok := strings.Cut(line, ":")
	if !ok {
		return false
	}
	controllers, _, ok := strings.Cut(rest, ":")
	if !ok {
		return false
	}
	for _, c := range strings.Split(controllers, ",") {
		if c == controller {
			return true
		}
	}
	return false
}

// findContainerIdByCgroupBatch resolves container IDs for several pids via
// findContainerIdByCgroup and returns a map from pid to container ID.
//
// A pid that cannot be resolved is left out of the result instead of failing
// the whole batch. That covers a process that has exited, an unreadable cgroup
// file, and a process not running in a Docker container. This matches how
// findContainerIdByNSBatch treats pids it cannot resolve. The error result is
// currently always nil; it is kept for signature compatibility.
func findContainerIdByCgroupBatch(pids []uint64) (map[uint64]string, error) {
	results := make(map[uint64]string, len(pids))
	for _, pid := range pids {
		cid, err := findContainerIdByCgroup(pid)
		if err != nil || cid == nil {
			continue
		}
		results[pid] = *cid
	}
	return results, nil
}

// findContainerIdByNS resolves the Docker container of pid by comparing pid
// namespaces.
//
// The steps are:
//  1. Read pid's pid namespace; failure is returned as an error.
//  2. Refresh the cached container info, creating it on first use.
//  3. Return the ID of a cached container whose main process (State.Pid) is
//     in the same pid namespace.
//
// Containers whose namespace cannot be read are ignored. It returns nil, nil
// when no container matches.
//
// NOTE(review): if several containers share one pid namespace (e.g.
// --pid=host or --pid=container:<id>), which one is returned depends on map
// iteration order — confirm whether that needs handling.
func findContainerIdByNS(pid uint64) (*string, error) {
	targetNs, err := utils.GetPidNS(pid)
	if err != nil {
		return nil, err
	}

	var refreshErr error
	if ContainerInfo != nil {
		refreshErr = ContainerInfo.Update()
	} else {
		refreshErr = initContainerInfo()
	}
	if refreshErr != nil {
		return nil, refreshErr
	}

	inspects, unlocker := ContainerInfo.Get()
	defer unlocker.Unlock()
	for id, inspect := range inspects {
		ns, nsErr := utils.GetPidNS(uint64(inspect.State.Pid))
		if nsErr == nil && ns == targetNs {
			found := id
			return &found, nil
		}
	}
	return nil, nil
}

// findContainerIdByNSBatch resolves container IDs for several pids by
// comparing pid namespaces, and returns a map from pid to container ID.
//
// The cached container info is refreshed first, or created if it does not
// exist yet; a failed refresh is returned as an error. An index from pid
// namespace to container ID is then built from each cached container's main
// process (State.Pid). Each requested pid whose namespace is in that index is
// mapped to the container ID.
//
// Pids whose namespace cannot be read, or that match no container, are left
// out of the result. Containers whose namespace cannot be read are skipped
// rather than failing the whole batch; findContainerIdByNS does the same.
//
// NOTE(review): if several containers share one pid namespace (e.g.
// --pid=host or --pid=container:<id>), only one of them survives in the index
// and which one depends on map iteration order — confirm whether that matters.
func findContainerIdByNSBatch(pids []uint64) (map[uint64]string, error) {
	if ContainerInfo == nil {
		if err := initContainerInfo(); err != nil {
			return nil, err
		}
	} else if err := ContainerInfo.Update(); err != nil {
		return nil, err
	}
	info, lock := ContainerInfo.Get()
	defer lock.Unlock()

	ns2cid := make(map[uint64]string, len(info))
	for cid, inspect := range info {
		containerNs, err := utils.GetPidNS(uint64(inspect.State.Pid))
		if err != nil {
			// The container may have stopped since it was listed, or its /proc
			// entry may be unreadable; one such container must not fail every
			// pid in the batch.
			continue
		}
		ns2cid[containerNs] = cid
	}

	results := make(map[uint64]string, len(pids))
	for _, pid := range pids {
		ns, err := utils.GetPidNS(pid)
		if err != nil {
			// The process may have exited or be unreadable; leave it out.
			continue
		}
		if cid, ok := ns2cid[ns]; ok {
			results[pid] = cid
		}
	}
	return results, nil
}

// getContainerInfo queries the Docker daemon for the running containers
// (ContainerListOptions.All is false) and inspects each one.
//
// It returns two maps keyed by container ID: the inspect responses and the
// list summaries. The client is configured from the environment
// (client.FromEnv) and is created and closed on every call.
//
// All daemon requests share one deadline, dockerQueryTimeout. An unresponsive
// daemon therefore cannot block the caller indefinitely; this matters because
// package initialization reaches this function through init.
//
// NOTE(review): a container that exits between ContainerList and
// ContainerInspect makes the inspect fail, which fails the whole refresh;
// consider skipping not-found errors.
func getContainerInfo() (map[string]container.InspectResponse, map[string]container.Summary, error) {
	const dockerQueryTimeout = 30 * time.Second

	cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
	if err != nil {
		return nil, nil, fmt.Errorf("creating docker client: %w", err)
	}
	defer func() {
		// Best effort: a close failure does not affect the data already read.
		_ = cli.Close()
	}()

	ctx, cancel := context.WithTimeout(context.Background(), dockerQueryTimeout)
	defer cancel()

	summaries, err := cli.ContainerList(ctx, client.ContainerListOptions{All: false})
	if err != nil {
		return nil, nil, fmt.Errorf("listing containers: %w", err)
	}
	inspects := make(map[string]container.InspectResponse, len(summaries))
	lists := make(map[string]container.Summary, len(summaries))
	for _, c := range summaries {
		inspect, inspectErr := cli.ContainerInspect(ctx, c.ID)
		if inspectErr != nil {
			return nil, nil, fmt.Errorf("inspecting container %s: %w", c.ID, inspectErr)
		}
		inspects[c.ID] = inspect
		lists[c.ID] = c
	}
	return inspects, lists, nil
}
