package main

import (
	"context"
	"fmt"
	"get-container/docker"
	"log"
	"os"
	"regexp"
	"slices"
	"sort"
	"strconv"
	"strings"
	"time"

	"get-container/cmd/pid_of_docker/logic"

	"github.com/charmbracelet/lipgloss"
	"github.com/moby/moby/client"
	"github.com/spf13/pflag"
)

/*
容器名          容器id       创建者      使用的卡   容器中的进程数
ContainerName  ContainerID  CreateUser  UsedCard  PidsInContainer
*/

// Cinfo collects everything this tool reports about a single docker container.
type Cinfo struct {
	Name, Id string            // container name (leading "/" stripped) and full container ID
	User     []string          // creator user names taken from the label or inferred from mount paths
	Pid      []int             // PIDs listed by `docker top` for this container
	Cards    []*logic.CardInfo // cards whose process list shares at least one PID with Pid
}

// Format renders c as one table row for printInfo.
//
// The returned strings are, in order: the container name, the first 12
// characters of the container ID, the de-duplicated creator user names joined
// by ",", the card indexes in ascending numeric order joined by ",", and the
// number of PIDs found in the container. The user and card columns fall back
// to " - " when empty. The returned ints are the byte length of each string;
// the caller uses them to size the table columns.
//
// Side effect: RemoveDuplicates compacts c.User in place, so the contents of
// c.User's backing array may be rearranged.
func (c *Cinfo) Format() ([5]string, [5]int) {
	user, card := " - ", " - "
	if len(c.User) != 0 {
		user = strings.Join(RemoveDuplicates(c.User), ",")
	}
	if len(c.Cards) != 0 {
		// Sort by the numeric index before formatting: sorting the formatted
		// strings would put "10" before "2". Sort a copy so c.Cards keeps its
		// original order.
		cards := slices.Clone(c.Cards)
		slices.SortFunc(cards, func(a, b *logic.CardInfo) int {
			switch {
			case a.Index < b.Index:
				return -1
			case a.Index > b.Index:
				return 1
			default:
				return 0
			}
		})
		ss := make([]string, 0, len(cards))
		for _, v := range cards {
			ss = append(ss, fmt.Sprintf("%d", v.Index))
		}
		card = strings.Join(ss, ",")
	}
	// Guard the slice so an ID shorter than 12 bytes cannot panic.
	id := c.Id
	if len(id) > 12 {
		id = id[:12]
	}
	rs := [5]string{c.Name, id, user, card, strconv.Itoa(len(c.Pid))}
	var rl [5]int
	for k, v := range rs {
		rl[k] = len(v)
	}
	return rs, rl
}

// Mount-source patterns used to infer a container's creator. Both are
// case-insensitive and capture the user name as submatch 1.
var (
	// RegUser matches /public<N>/home/<user> optionally followed by /... .
	RegUser = regexp.MustCompile(`^(?i)/public[0-9]*/home/([0-9a-z_]+)(?:|/.*)$`)

	// RegUserPublic matches /public<N>/home/{public_user,locauser,localuser}/<user>
	// optionally followed by /... .
	RegUserPublic = regexp.MustCompile(`^(?i)/public[0-9]*/home/(?:public_user|locauser|localuser)/([0-9a-z_]+)(?:|/.*)$`)
)

// flagHelp is set by -h/--help; main then prints a short description and exits.
var flagHelp = pflag.BoolP("help", "h", false, "show usage")

// main lists the running docker containers and prints, for each one, its name,
// short ID, the users inferred as its creator (from the com.sugon.username
// label, or else from its mount sources), the cards used by its processes,
// and how many processes it contains.
func main() {

	pflag.Parse()
	if *flagHelp {
		fmt.Println("这个工具用于查看docker容器的创建者和使用计算卡的情况")
		fmt.Println("容器创建者由容器挂载的目录和标签com.sugon.username推导出来")
		fmt.Println("本工具支持查看nvidia和hycu")
		os.Exit(0)
	}

	cli, err := docker.GetDockerClient()
	if err != nil {
		log.Fatalf("can't connect to docker daemon: %v", err)
	}

	// Bound `docker ps` to one second. cancel is called explicitly instead of
	// deferred because the os.Exit below would skip deferred calls.
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	csum, err := cli.ContainerList(ctx, client.ContainerListOptions{})
	cancel()
	if err != nil {
		log.Printf("error get container list: %v \n", err)
		cli.Close()
		os.Exit(1)
	}

	result := make(map[string]*Cinfo)
	for _, i := range csum {
		c := &Cinfo{Id: i.ID, Cards: make([]*logic.CardInfo, 0)}
		result[i.ID] = c
		if len(i.Names) > 0 {
			c.Name = strings.TrimPrefix(i.Names[0], "/")
		}
		if name, ok := i.Labels["com.sugon.username"]; ok {
			c.User = []string{name}
		} else {
			c.User = make([]string, 0, 4)
			for _, v := range i.Mounts {
				// Try RegUserPublic first: RegUser also matches those paths but
				// would capture the intermediate directory (e.g. "public_user").
				if f := RegUserPublic.FindStringSubmatch(v.Source); len(f) >= 2 {
					c.User = append(c.User, f[1])
				} else if f := RegUser.FindStringSubmatch(v.Source); len(f) >= 2 {
					c.User = append(c.User, f[1])
				}
			}
		}

		// Bound `docker top` to one second per container.
		topCtx, topCancel := context.WithTimeout(context.Background(), time.Second)
		top, err := cli.ContainerTop(topCtx, i.ID, nil)
		topCancel()
		if err != nil {
			fmt.Printf("get pid of container %s failed: %v \n", i.ID, err)
			continue
		}
		index := slices.Index(top.Titles, "PID")
		if index == -1 {
			fmt.Printf("can't find PID field in ContainerID: %s \n", i.ID)
			continue
		}
		c.Pid = make([]int, 0, len(top.Processes))
		for _, v := range top.Processes {
			if index >= len(v) {
				continue // malformed row: skip it rather than panic
			}
			if pid, err := strconv.Atoi(v[index]); err == nil {
				c.Pid = append(c.Pid, pid)
			}
		}
	}
	cli.Close()

	// Prefer DCU information; fall back to NVIDIA when none is available.
	var cardInfo map[int]*logic.CardInfo

	if a := logic.DCUInfo(); a != nil {
		cardInfo = a
	} else if a := logic.NVIDIAInfo(); a != nil {
		cardInfo = a
	}
	if cardInfo != nil {
		// A card belongs to a container when any of its PIDs runs inside it.
		for _, v := range result {
			for _, vin := range cardInfo {
				if contain(vin.Pids, v.Pid) {
					v.Cards = append(v.Cards, vin)
				}
			}
		}
	}
	printInfo(result)
}

// contain reports whether s1 and s2 have at least one element in common.
// If either slice is empty, the result is false.
func contain(s1, s2 []int) bool {
	return slices.ContainsFunc(s1, func(pid int) bool {
		return slices.Contains(s2, pid)
	})
}

/*
容器名          容器id       创建者      使用的卡   容器中的进程数
ContainerName  ContainerID  CreateUser  UsedCard  PidsInContainer
*/

// printInfo prints a header line followed by one colored row per container.
//
// Every cell is padded to the widest cell (or header) of its column, and cells
// are separated by a single space. Rows are ordered by short container ID and
// cycle through three foreground colors.
func printInfo(info map[string]*Cinfo) {
	header := lipgloss.NewStyle().Foreground(lipgloss.Color("#6071f1"))
	rowStyles := []lipgloss.Style{
		lipgloss.NewStyle().Foreground(lipgloss.Color("#82ff60")),
		lipgloss.NewStyle().Foreground(lipgloss.Color("#54fff6")),
		lipgloss.NewStyle().Foreground(lipgloss.Color("#eef867")),
	}
	columns := [5]string{"ContainerName", "ContainerID", "CreateUser", "UsedCard", "PidsInContainer"}

	var widths [5]int
	for i := range columns {
		widths[i] = len(columns[i])
	}
	rows := make([][5]string, 0, len(info))
	for _, ci := range info {
		cells, cellWidths := ci.Format()
		rows = append(rows, cells)
		for i, w := range cellWidths {
			widths[i] = max(widths[i], w)
		}
	}

	// render pads each cell to its column width and joins the cells with spaces.
	render := func(cells [5]string) string {
		padded := make([]string, len(cells))
		for i, cell := range cells {
			padded[i] = formatStr(cell, widths[i])
		}
		return strings.Join(padded, " ")
	}

	fmt.Println(header.Render(render(columns)))
	sort.Slice(rows, func(a, b int) bool {
		return rows[a][1] < rows[b][1]
	})
	for n, row := range rows {
		fmt.Println(rowStyles[n%len(rowStyles)].Render(render(row)))
	}
}

// formatStr fits raw into exactly l bytes. Shorter strings are right-padded
// with spaces and longer ones are truncated. Widths are byte counts, which
// matches the len-based widths produced by Format.
func formatStr(raw string, l int) string {
	if pad := l - len(raw); pad >= 0 {
		return raw + strings.Repeat(" ", pad)
	}
	return raw[:l]
}

// RemoveDuplicates returns s with repeated values dropped, keeping the first
// occurrence of each value in its original order. It compacts s in place: the
// result shares s's backing array and the contents of s are overwritten.
func RemoveDuplicates[T comparable](s []T) []T {
	seen := make(map[T]struct{}, len(s))
	out := s[:0]
	for _, v := range s {
		if _, dup := seen[v]; dup {
			continue
		}
		seen[v] = struct{}{}
		out = append(out, v)
	}
	return out
}
