package hydcu

import (
	"bufio"
	"errors"
	"github.com/golang/glog"
	"io/ioutil"
	"os"
	"path/filepath"
	"regexp"
	"strconv"
	"strings"
)

// Patterns used to pull single fields out of a KFD topology node's
// "properties" file. Each one matches the key, one whitespace character, and
// captures the run of decimal digits that follows.
var (
	// topoDrmRenderMinorRe captures the value of the "drm_render_minor" key.
	topoDrmRenderMinorRe = regexp.MustCompile(`drm_render_minor\s(\d+)`)

	// topoUniqueIdRe captures the value of the "unique_id" key.
	topoUniqueIdRe = regexp.MustCompile(`unique_id\s(\d+)`)
)

// GetDevIdsFromTopology maps each DRM render minor number to the "unique_id"
// of the KFD topology node that exposes it.
//
// topoRootParam optionally overrides the KFD sysfs root (default
// "/sys/class/kfd/kfd"); the override is used only when exactly one value is
// passed. Every <root>/topology/nodes/*/properties file is parsed. Nodes
// without a positive drm_render_minor are skipped, and nodes whose properties
// cannot be read or parsed are logged and skipped. An empty map is returned if
// the node files cannot be enumerated.
//
// Example result: {128: "8324688932758364225", 129: "8324688932758358177"}.
func GetDevIdsFromTopology(topoRootParam ...string) map[int]string {
	topoRoot := "/sys/class/kfd/kfd"
	if len(topoRootParam) == 1 {
		topoRoot = topoRootParam[0]
	}

	renderDevIds := make(map[int]string)
	nodeFiles, err := filepath.Glob(topoRoot + "/topology/nodes/*/properties")
	if err != nil {
		// Glob only fails on a malformed pattern (e.g. a topoRoot containing an
		// unbalanced '['). Report it and return an empty map rather than
		// terminating the whole process with glog.Fatalf.
		glog.Errorf("glob error: %s", err)
		return renderDevIds
	}

	for _, nodeFile := range nodeFiles {
		glog.Info("Parsing " + nodeFile)
		v, e := ParseTopologyProperties(nodeFile, topoDrmRenderMinorRe)
		if e != nil {
			glog.Error(e)
			continue
		}

		// A non-positive render minor means the node has no render device to map.
		if v <= 0 {
			continue
		}

		devID, e := ParseTopologyPropertiesString(nodeFile, topoUniqueIdRe)
		if e != nil {
			glog.Error(e)
			continue
		}

		renderDevIds[int(v)] = devID
	}
	return renderDevIds
}

// ParseTopologyProperties scans the file at path line by line and returns the
// integer captured by the first submatch of re on the first matching line.
//
// The capture is parsed with strconv.ParseInt in base 0, so Go base prefixes
// ("0x", "0o", "0b", a leading "0" for octal) are honoured. re must contain at
// least one capture group.
//
// An error is returned when the file cannot be opened, when reading it fails
// (including a line longer than bufio.MaxScanTokenSize), when no line matches
// re, or when the captured text is not a valid int64.
func ParseTopologyProperties(path string, re *regexp.Regexp) (int64, error) {
	f, err := os.Open(path)
	if err != nil {
		return 0, err
	}
	defer f.Close()

	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		m := re.FindStringSubmatch(scanner.Text())
		if m == nil {
			continue
		}
		return strconv.ParseInt(m[1], 0, 64)
	}
	// Scan stops on EOF *or* on a read error; report the latter instead of
	// misclassifying it as a missing property.
	if err := scanner.Err(); err != nil {
		return 0, err
	}
	return 0, errors.New("Topology property not found. Regex: " + re.String())
}

// ParseTopologyPropertiesString scans the file at path line by line and
// returns the text captured by the first submatch of re on the first matching
// line. re must contain at least one capture group.
//
// An error is returned when the file cannot be opened, when reading it fails
// (including a line longer than bufio.MaxScanTokenSize), or when no line
// matches re.
func ParseTopologyPropertiesString(path string, re *regexp.Regexp) (string, error) {
	f, err := os.Open(path)
	if err != nil {
		return "", err
	}
	defer f.Close()

	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		if m := re.FindStringSubmatch(scanner.Text()); m != nil {
			return m[1], nil
		}
	}
	// Scan stops on EOF *or* on a read error; report the latter instead of
	// misclassifying it as a missing property.
	if err := scanner.Err(); err != nil {
		return "", err
	}
	return "", errors.New("Topology property not found. Regex: " + re.String())
}

// GetNodeIdsFromTopology maps each DRM render minor number to the index of the
// KFD topology node that exposes it. The index is the name of the node's
// directory under <root>/topology/nodes.
//
// topoRootParam optionally overrides the KFD sysfs root (default
// "/sys/class/kfd/kfd"); the override is used only when exactly one value is
// passed. Nodes without a positive drm_render_minor are skipped, and nodes
// whose properties cannot be parsed or whose directory name is not an integer
// are logged and skipped. An empty map is returned if the node files cannot be
// enumerated.
//
// Example result: {128: 4, 129: 5}.
func GetNodeIdsFromTopology(topoRootParam ...string) map[int]int {
	topoRoot := "/sys/class/kfd/kfd"
	if len(topoRootParam) == 1 {
		topoRoot = topoRootParam[0]
	}

	renderNodeIds := make(map[int]int)
	nodeFiles, err := filepath.Glob(topoRoot + "/topology/nodes/*/properties")
	if err != nil {
		// Glob only fails on a malformed pattern (e.g. a topoRoot containing an
		// unbalanced '['). Report it and return an empty map rather than
		// terminating the whole process with glog.Fatalf.
		glog.Errorf("glob error: %s", err)
		return renderNodeIds
	}

	for _, nodeFile := range nodeFiles {
		glog.Info("Parsing " + nodeFile)
		v, e := ParseTopologyProperties(nodeFile, topoDrmRenderMinorRe)
		if e != nil {
			glog.Error(e)
			continue
		}

		// A non-positive render minor means the node has no render device to map.
		if v <= 0 {
			continue
		}

		// The node index is the name of the directory holding "properties".
		nodeIndex := filepath.Base(filepath.Dir(nodeFile))
		nodeId, convErr := strconv.Atoi(nodeIndex)
		if convErr != nil {
			glog.Errorf("Failed to convert node index %s to int: %v", nodeIndex, convErr)
			continue
		}
		renderNodeIds[int(v)] = nodeId
	}
	return renderNodeIds
}
// GetHYDCUs discovers HYDCU devices from sysfs and returns one entry per
// device, keyed by the base name of the device's sysfs directory (the PCI
// address for physical devices, "amdgpu_xcp_*" for platform partition
// devices).
//
// Each entry holds:
//   - "card" (int): N from drm/cardN (0 if absent).
//   - "renderD" (int): N from drm/renderDN (0 if absent).
//   - "devID" (string): KFD unique_id for that render minor ("" if unknown).
//   - "computePartitionType" (string): lower-cased current_compute_partition
//     ("" if unreadable).
//   - "memoryPartitionType" (string): lower-cased current_memory_partition
//     ("" if unreadable).
//   - "numaNode" (int): value of numa_node.
//   - "nodeId" (int): KFD topology node index for that render minor (0 if
//     unknown).
//
// Physical devices are skipped when numa_node cannot be read or parsed, or
// when they have no drm/ entries. Platform partition devices are kept only if
// their render minor is known to the KFD topology and an already-discovered
// device with the same devID and both partition types set exists; they
// inherit that device's partition types and NUMA node.
func GetHYDCUs() map[string]map[string]interface{} {
	matches, err := filepath.Glob("/sys/module/hy*cu/drivers/pci:hy*cu/[0-9a-fA-F][0-9a-fA-F][0-9a-fA-F][0-9a-fA-F]:*")
	if err != nil {
		glog.Warningf("Failed to find hydcu driver directory: %s", err)
		return make(map[string]map[string]interface{})
	}

	devices := make(map[string]map[string]interface{})
	// e.g. renderDevIds = {128: "8324688932758364225", 129: "8324688932758358177"}
	renderDevIds := GetDevIdsFromTopology()
	// e.g. renderNodeIds = {128: 4, 129: 5}
	renderNodeIds := GetNodeIdsFromTopology()

	// Physical PCI devices bound to the hy*cu driver.
	for _, path := range matches {
		computePartitionType := readHYDCUPartitionFile(path, "current_compute_partition")
		memoryPartitionType := readHYDCUPartitionFile(path, "current_memory_partition")
		numaNode, ok := readHYDCUNumaNode(path)
		if !ok {
			continue
		}

		glog.Info(path)
		drm, hasDrm := scanHYDCUDrmDir(path, renderDevIds, renderNodeIds)
		if !hasDrm {
			continue
		}
		devices[filepath.Base(path)] = newHYDCUDeviceEntry(drm, computePartitionType, memoryPartitionType, numaNode)
	}

	// Platform partition devices. The pattern is a constant without character
	// classes, so Glob cannot return ErrBadPattern here.
	platformMatches, _ := filepath.Glob("/sys/devices/platform/amdgpu_xcp_*")
	for _, path := range platformMatches {
		glog.Info(path)
		drm, _ := scanHYDCUDrmDir(path, renderDevIds, renderNodeIds)

		// Some visible renderD nodes are not valid; validity depends on the
		// KFD topology knowing the render minor. A device with no renderD entry
		// has renderD == 0, which is never a key of renderDevIds.
		if _, exists := renderDevIds[drm.renderD]; !exists {
			continue
		}
		// Take the partition types and NUMA node from the real GPU, or from
		// another partition, that shares this devID.
		computePartitionType, memoryPartitionType, numaNode, ok := inheritHYDCUPartitionInfo(devices, drm.devID)
		if !ok {
			continue
		}
		devices[filepath.Base(path)] = newHYDCUDeviceEntry(drm, computePartitionType, memoryPartitionType, numaNode)
	}
	glog.Infof("Devices map: %v", devices)
	return devices
}

// hydcuDrmInfo holds the identifiers gathered from one device's drm/ directory.
type hydcuDrmInfo struct {
	card    int
	renderD int
	devID   string
	nodeID  int
}

// readHYDCUPartitionFile returns the trimmed, lower-cased contents of
// <devicePath>/<name>, or "" after logging a warning if it cannot be read.
func readHYDCUPartitionFile(devicePath, name string) string {
	file := filepath.Join(devicePath, name)
	data, err := ioutil.ReadFile(file)
	if err != nil {
		glog.Warningf("Failed to read '%s' file at %s: %s", name, file, err)
		return ""
	}
	return strings.ToLower(strings.TrimSpace(string(data)))
}

// readHYDCUNumaNode reads and parses <devicePath>/numa_node. ok is false, after
// logging a warning, when the file is unreadable or does not hold an integer.
func readHYDCUNumaNode(devicePath string) (numaNode int, ok bool) {
	numaNodeFile := filepath.Join(devicePath, "numa_node")
	data, err := ioutil.ReadFile(numaNodeFile)
	if err != nil {
		glog.Warningf("Failed to read 'numa_node' file at %s: %s", numaNodeFile, err)
		return -1, false
	}
	numaNode, err = strconv.Atoi(strings.TrimSpace(string(data)))
	if err != nil {
		glog.Warningf("Failed to convert 'numa_node' value to int: %s", err)
		return -1, false
	}
	return numaNode, true
}

// scanHYDCUDrmDir inspects <devicePath>/drm/*. It takes the card index from a
// "cardN" entry and the render minor from a "renderDN" entry, then resolves
// the render minor to a KFD unique_id and node index via the given maps.
// Anything that cannot be determined keeps its zero value, so state never
// leaks between devices. found reports whether drm/ had any entry at all.
func scanHYDCUDrmDir(devicePath string, renderDevIds map[int]string, renderNodeIds map[int]int) (info hydcuDrmInfo, found bool) {
	devPaths, _ := filepath.Glob(filepath.Join(devicePath, "drm", "*"))
	for _, devPath := range devPaths {
		name := filepath.Base(devPath)
		switch {
		case strings.HasPrefix(name, "card"):
			// A non-numeric suffix yields 0, matching Atoi's result on error.
			info.card, _ = strconv.Atoi(strings.TrimPrefix(name, "card"))
		case strings.HasPrefix(name, "renderD"):
			info.renderD, _ = strconv.Atoi(strings.TrimPrefix(name, "renderD"))
			if val, exists := renderDevIds[info.renderD]; exists {
				info.devID = val
			}
			if id, exists := renderNodeIds[info.renderD]; exists {
				info.nodeID = id
			}
		}
	}
	return info, len(devPaths) > 0
}

// inheritHYDCUPartitionInfo returns the partition types and NUMA node of an
// already-discovered device that shares devID and has both partition types
// set. ok is false when no such device exists.
//
// An explicit flag is used instead of a -1 sentinel because sysfs numa_node
// can itself read -1 when the device has no NUMA affinity.
func inheritHYDCUPartitionInfo(devices map[string]map[string]interface{}, devID string) (computePartitionType, memoryPartitionType string, numaNode int, ok bool) {
	for _, device := range devices {
		if device["devID"] != devID {
			continue
		}
		compute, _ := device["computePartitionType"].(string)
		memory, _ := device["memoryPartitionType"].(string)
		if compute == "" || memory == "" {
			continue
		}
		numa, _ := device["numaNode"].(int)
		return compute, memory, numa, true
	}
	return "", "", -1, false
}

// newHYDCUDeviceEntry builds the per-device map returned by GetHYDCUs.
func newHYDCUDeviceEntry(drm hydcuDrmInfo, computePartitionType, memoryPartitionType string, numaNode int) map[string]interface{} {
	return map[string]interface{}{
		"card":                 drm.card,
		"renderD":              drm.renderD,
		"devID":                drm.devID,
		"computePartitionType": computePartitionType,
		"memoryPartitionType":  memoryPartitionType,
		"numaNode":             numaNode,
		"nodeId":               drm.nodeID,
	}
}

// UniquePartitionConfigCount counts devices per combined partition
// configuration. The key is computePartitionType + "_" + memoryPartitionType.
//
// A device is counted only when both "computePartitionType" and
// "memoryPartitionType" are non-empty strings. Missing or non-string values
// are treated as unset rather than causing a panic, since callers may pass
// maps not built by GetHYDCUs.
func UniquePartitionConfigCount(devices map[string]map[string]interface{}) map[string]int {
	partitionCountMap := make(map[string]int)

	for _, device := range devices {
		computePartitionType, _ := device["computePartitionType"].(string)
		memoryPartitionType, _ := device["memoryPartitionType"].(string)
		if computePartitionType == "" || memoryPartitionType == "" {
			continue
		}
		partitionCountMap[computePartitionType+"_"+memoryPartitionType]++
	}

	glog.Infof("Partition counts: %v", partitionCountMap)
	return partitionCountMap
}

func IsHomogeneous() bool {
	dcus := GetHYDCUs()
	partitionCountMap := UniquePartitionConfigCount(dcus)

	return len(partitionCountMap) <= 1
}
