package query

import (
	"bytes"
	"dcu-container-toolkit/internal/hydcu"
	"fmt"
	"os/exec"
	"regexp"
	"sort"
	"strconv"
	"strings"
)

// DcuInfo collects what ShowStatus reports for one DCU.
//
// DcuId is the device index (ShowStatus also uses it as the map key).
// Pid holds the process IDs, as strings, that hy-smi attributes to the device.
// ContainerName holds the names QueryName returned for those processes; an
// entry is empty when QueryName found no Docker container for the process.
// Uuid is the device unique id; ShowStatus stores it as "0x" followed by the
// upper-cased hex digits.
type DcuInfo struct {
	DcuId         int
	Pid           []string
	ContainerName []string
	Uuid          string
}

// DCUProcess is one process/device pair parsed by Exec from the output of
// `hy-smi --showpids`.
//
// Pid is the run of digits captured after "PID:".
// Index is the text captured after "HCU Index:"/"DCU Index:" with surrounding
// whitespace, square brackets and single quotes trimmed. It is kept as a
// string; ShowStatus converts it with strconv.Atoi and skips entries that do
// not parse.
type DCUProcess struct {
	Pid   string
	Index string
}

// parseDCUsList expands a user-supplied DCU selector into device indexes.
//
// dcus is either "all" (any letter case) or a comma-separated list whose
// entries may be:
//   - a UUID: anything starting with "0x"/"0X", or a bare hexadecimal string
//     longer than 8 characters. It is lower-cased and looked up in the map
//     returned by hydcu.GetUniqueIdToDeviceIndexMap, first with and then
//     without the "0x" prefix;
//   - an inclusive range "start-end" of decimal indexes;
//   - a single decimal index.
//
// Whitespace around entries (and around range bounds) is ignored, so "0, 1"
// selects the same devices as "0,1".
//
// It returns, in order:
//   - the selected indexes, sorted ascending (duplicates are kept);
//   - entries that do not name a known device (unknown UUIDs, non-numeric
//     text, and every out-of-range index, including those inside a range);
//   - ranges that are malformed (non-numeric bound, negative bound, or
//     start > end);
//   - an error, only when hydcu.GetHYDCUs fails.
//
// A failure to load the UUID map is not fatal: a warning is printed and every
// UUID entry is then reported as invalid.
func parseDCUsList(dcus string) ([]int, []string, []string, error) {
	// isHexString reports whether s is non-empty and made only of hex digits.
	isHexString := func(s string) bool {
		if len(s) == 0 {
			return false
		}
		for _, c := range s {
			switch {
			case c >= '0' && c <= '9', c >= 'a' && c <= 'f', c >= 'A' && c <= 'F':
				// hex digit, keep scanning
			default:
				return false
			}
		}
		return true
	}

	validDCUs := []int{}
	invalidDCUs := []string{}
	invalidDCUsRange := []string{}

	dcusInfo, err := hydcu.GetHYDCUs()
	if err != nil {
		return []int{}, []string{}, []string{}, fmt.Errorf("Failed to get DCU info, Error: %w", err)
	}

	if strings.EqualFold(strings.TrimSpace(dcus), "all") {
		for i := 0; i < len(dcusInfo); i++ {
			validDCUs = append(validDCUs, i)
		}
		return validDCUs, []string{}, []string{}, nil
	}

	uuidToDCUIdMap, err := hydcu.GetUniqueIdToDeviceIndexMap()
	if err != nil {
		// Best effort: index and range entries can still be resolved.
		fmt.Printf("Failed to get UUID to DCU Id mappings: %v\n", err)
		uuidToDCUIdMap = make(map[string][]int)
	}

	for _, entry := range strings.Split(dcus, ",") {
		entry = strings.TrimSpace(entry)

		switch {
		case strings.HasPrefix(entry, "0x") || strings.HasPrefix(entry, "0X") ||
			(len(entry) > 8 && isHexString(entry)):
			// UUID entry: try the "0x"-prefixed key first, then the bare one.
			uuid := strings.ToLower(entry)
			if !strings.HasPrefix(uuid, "0x") {
				uuid = "0x" + uuid
			}
			ids, found := uuidToDCUIdMap[uuid]
			if !found {
				ids, found = uuidToDCUIdMap[strings.TrimPrefix(uuid, "0x")]
			}
			if found {
				validDCUs = append(validDCUs, ids...)
			} else {
				invalidDCUs = append(invalidDCUs, entry)
			}

		case strings.Contains(entry, "-"):
			bounds := strings.SplitN(entry, "-", 2)
			start, errStart := strconv.Atoi(strings.TrimSpace(bounds[0]))
			end, errEnd := strconv.Atoi(strings.TrimSpace(bounds[1]))
			if errStart != nil || errEnd != nil || start < 0 || end < 0 || start > end {
				invalidDCUsRange = append(invalidDCUsRange, entry)
				continue
			}
			// NOTE(review): every out-of-range index gets its own entry, so a
			// very large upper bound produces a very large invalid list.
			for i := start; i <= end; i++ {
				if i < len(dcusInfo) {
					validDCUs = append(validDCUs, i)
				} else {
					invalidDCUs = append(invalidDCUs, strconv.Itoa(i))
				}
			}

		default:
			i, convErr := strconv.Atoi(entry)
			if convErr == nil && i >= 0 && i < len(dcusInfo) {
				validDCUs = append(validDCUs, i)
			} else {
				invalidDCUs = append(invalidDCUs, entry)
			}
		}
	}

	sort.Ints(validDCUs)

	return validDCUs, invalidDCUs, invalidDCUsRange, nil
}


// CheckHySmi locates hy-smi by running `whereis hy-smi`.
//
// The command's standard output is split on whitespace and the second field
// (the first one after the "hy-smi:" label) is returned. When the output has
// fewer than two fields the result is an empty string with a nil error, so
// callers must treat "" as "not found". An error from running whereis is
// returned unchanged together with an empty string.
func CheckHySmi() (string, error) {
	var stdout bytes.Buffer
	whereis := exec.Command("whereis", "hy-smi")
	whereis.Stdout = &stdout

	if runErr := whereis.Run(); runErr != nil {
		return "", runErr
	}

	tokens := strings.Fields(stdout.String())
	if len(tokens) < 2 {
		return "", nil
	}
	return tokens[1], nil
}

// Exec runs `<hysmi> --showpids` and extracts the process/device pairs from
// its standard output.
//
// hysmi is the path (or name) of the hy-smi executable. Each result pairs the
// digits following "PID:" with the rest of the line following the next
// "HCU Index:" or "DCU Index:" label; whitespace, square brackets and single
// quotes are trimmed from both ends of that index text.
//
// On failure to run the command it returns a nil slice and an error wrapping
// the underlying one.
//
// NOTE(review): the index text is kept verbatim between the trimmed ends, so
// a line listing several indexes stays a single string — confirm against the
// real hy-smi output format.
func Exec(hysmi string) ([]DCUProcess, error) {
	var out bytes.Buffer
	cmd := exec.Command(hysmi, "--showpids")
	cmd.Stdout = &out

	var results []DCUProcess
	if err := cmd.Run(); err != nil {
		return results, fmt.Errorf("Failed to run hy-smi command, Error: %w", err)
	}

	// [HD] rather than [H|D]: inside a character class '|' is a literal, so
	// the old pattern also accepted "|CU Index:".
	reBlock := regexp.MustCompile(`PID:\s*(\d+)[\s\S]*?[HD]CU Index:\s*(.*)`)

	for _, m := range reBlock.FindAllStringSubmatch(out.String(), -1) {
		index := strings.Trim(strings.TrimSpace(m[2]), "[]' ")
		results = append(results, DCUProcess{Pid: m[1], Index: index})
	}

	return results, nil
}

// QueryName returns the name of the Docker container that process pid runs
// in.
//
// It reads /proc/<pid>/cgroup and looks for a 64-character hexadecimal
// container id in either layout Docker uses:
//   - "docker/<id>"        (cgroupfs cgroup driver)
//   - "docker-<id>.scope"  (systemd cgroup driver, including cgroup v2)
//
// The id is then matched against `docker ps -a`, whose ID column is a prefix
// of the full id.
//
// It returns "" with a nil error when the cgroup file contains no Docker id
// or when no listed container matches. An error is returned when the cgroup
// file cannot be read or the docker command fails.
func QueryName(pid string) (string, error) {
	cmd := exec.Command("cat", "/proc/"+pid+"/cgroup")
	var out bytes.Buffer
	cmd.Stdout = &out
	if err := cmd.Run(); err != nil {
		return "", err
	}

	// Accept both "docker/<id>" and "docker-<id>"; the original pattern only
	// handled the former and so missed containers under the systemd driver.
	re := regexp.MustCompile(`docker[/-]([0-9a-f]{64})`)
	matches := re.FindStringSubmatch(out.String())
	if len(matches) < 2 {
		return "", nil
	}
	containerID := matches[1]

	cmd2 := exec.Command("docker", "ps", "-a", "--format", "{{.ID}} {{.Names}}")
	var psOut bytes.Buffer
	cmd2.Stdout = &psOut
	if err := cmd2.Run(); err != nil {
		return "", err
	}

	for _, line := range strings.Split(psOut.String(), "\n") {
		fields := strings.Fields(line)
		// fields[0] is the (possibly truncated) id printed by docker ps.
		if len(fields) >= 2 && strings.HasPrefix(containerID, fields[0]) {
			return fields[1], nil
		}
	}
	return "", nil
}

// ShowStatus prints a table of the containers currently using the selected
// DCUs.
//
// dcus is a selector understood by parseDCUsList ("all", indexes, ranges or
// UUIDs); entries it reports as invalid are ignored here.
//
// Steps:
//  1. locate hy-smi (CheckHySmi) and list device processes (Exec);
//  2. resolve the selector to device indexes and look up each one's UUID;
//  3. resolve every process on a selected device to a container name
//     (QueryName); processes whose lookup fails are skipped;
//  4. print one row per container, grouped by device in ascending index
//     order. Devices with no resolved container print no rows.
//
// It returns an error when hy-smi cannot be found or run, when the selector
// cannot be parsed, or when the UUID map cannot be loaded.
func ShowStatus(dcus string) error {
	hySmiPath, err := CheckHySmi()
	if err != nil {
		return fmt.Errorf("Failed to check hy-smi path, Error: %w", err)
	}
	if hySmiPath == "" {
		// CheckHySmi signals "not found" with an empty path and nil error.
		return fmt.Errorf("Failed to check hy-smi path, Error: hy-smi not found")
	}
	processes, err := Exec(hySmiPath)
	if err != nil {
		// Exec already prefixes its error with the same context.
		return err
	}
	validDCUs, _, _, err := parseDCUsList(dcus)
	if err != nil {
		return fmt.Errorf("Failed to parse DCUs list, Error: %w", err)
	}

	uuidToDCUIdMap, err := hydcu.GetUniqueIdToDeviceIndexMap()
	if err != nil {
		return fmt.Errorf("Failed to get UUID to DCU Id mappings: %w", err)
	}

	dcuinfos := make(map[int]DcuInfo, len(validDCUs))
	for _, dcu := range validDCUs {
		if _, seen := dcuinfos[dcu]; seen {
			continue // selector listed the same device twice
		}
		for uuid, dcuIds := range uuidToDCUIdMap {
			// A UUID may map to several indexes; match any of them, and
			// tolerate an empty slice instead of indexing dcuIds[0].
			owns := false
			for _, id := range dcuIds {
				if id == dcu {
					owns = true
					break
				}
			}
			if !owns {
				continue
			}
			if strings.HasPrefix(uuid, "0x") || strings.HasPrefix(uuid, "0X") {
				uuid = uuid[2:]
			}
			dcuinfos[dcu] = DcuInfo{DcuId: dcu, Uuid: "0x" + strings.ToUpper(uuid)}
			break
		}
	}

	for _, process := range processes {
		index, err := strconv.Atoi(process.Index)
		if err != nil {
			continue
		}
		dcu, exists := dcuinfos[index]
		if !exists {
			continue
		}
		name, err := QueryName(process.Pid)
		if err != nil {
			// Best effort: the process may have exited, or docker may be
			// unavailable. Skip it without recording a half-filled entry.
			continue
		}
		dcu.Pid = append(dcu.Pid, process.Pid)
		dcu.ContainerName = append(dcu.ContainerName, name)
		dcuinfos[index] = dcu // structs are values: write the modified copy back
	}

	// Map iteration order is random; sort the ids for stable output.
	ids := make([]int, 0, len(dcuinfos))
	for id := range dcuinfos {
		ids = append(ids, id)
	}
	sort.Ints(ids)

	fmt.Println(strings.Repeat("-", 120))
	fmt.Printf("%-40s%-50s%-20s\n", "DCU Id", "UUID", "Container Names")
	fmt.Println(strings.Repeat("-", 120))
	for _, id := range ids {
		info := dcuinfos[id]
		for idx, name := range info.ContainerName {
			if idx == 0 {
				fmt.Printf("%-40v%-50s%-20v\n", id, info.Uuid, name)
			} else {
				// Continuation row: only the container column is filled.
				fmt.Printf("%-40v%-50s%-20v\n", "", "", name)
			}
		}
	}
	fmt.Println(strings.Repeat("-", 120))
	return nil
}

