package backend

import (
	"context"
	"fmt"
	"get-container/cmd/hytop/lib"
	"get-container/gpu"
	"get-container/utils"
	"math"
	"os"
	"sort"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/samber/mo"
	"github.com/shirou/gopsutil/v3/process"
)

var (
	// DCUSInfoMap records per-DCU information; Init allocates it.
	DCUSInfoMap *DCUInfoMap
	// DockerPidInfo maps process IDs to container information; Init allocates it.
	DockerPidInfo *DockerProcessMap

	// User is the value of $USER, or the numeric uid when $USER is unset.
	User string
	// HostName is the local host name as reported by os.Hostname (error ignored).
	HostName string
	// DriverVersion is the system driver version reported by Rocmlib.
	DriverVersion string

	// Rocmlib is the ROCm library handle; Init replaces it with lib.GetRocmlib().
	Rocmlib = lib.Rocmlib_instance

	// stopCtx and cancelFunc control the background refresh loop started by Init.
	stopCtx    context.Context
	cancelFunc context.CancelFunc
)

// ENVUser is the environment variable Init consults to determine the current user name.
const ENVUser = "USER"

// Init initializes the package state: the current host and user names, the DCU
// info map, the ROCm library handle and driver version, and the container/PID
// mapping. If the DCU count can be queried, it also starts a background
// goroutine that refreshes DockerPidInfo and the slow DCU info every 8 seconds
// until Shutdown is called.
//
// It returns the first error from Rocmlib.Init, Rocmlib.GetSystemDriverVersion
// or Rocmlib.GetDevNumber. Calling Init again stops the refresh loop started by
// the previous call instead of leaking it.
func Init() error {
	// Best-effort: if the host name cannot be determined it is left empty.
	HostName, _ = os.Hostname()
	if username, ok := os.LookupEnv(ENVUser); ok {
		User = username
	} else {
		User = strconv.Itoa(os.Getuid())
	}
	DCUSInfoMap = &DCUInfoMap{
		qinfo: make(map[int]*DCUQuickInfo),
		sinfo: make(map[int]*DCUSlowInfo),
	}
	Rocmlib = lib.GetRocmlib()
	_, err := Rocmlib.Init()
	if err != nil {
		return err
	}
	DriverVersion, err = Rocmlib.GetSystemDriverVersion()
	if err != nil {
		return err
	}
	_, err = Rocmlib.GetDevNumber()
	// The container/PID mapping is built even when the device count query failed.
	DockerPidInfo = &DockerProcessMap{}
	DockerPidInfo.Update()
	if err != nil {
		return err
	}

	// Stop a refresh loop left over from a previous Init so it does not leak.
	if cancelFunc != nil {
		cancelFunc()
	}
	ctx, cancel := context.WithCancel(context.Background())
	stopCtx, cancelFunc = ctx, cancel
	// Capture the current objects so the goroutine never reads package variables
	// that a later Init may reassign concurrently.
	dockerInfo, infoMap := DockerPidInfo, DCUSInfoMap
	go func() {
		ticker := time.NewTicker(8 * time.Second)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				dockerInfo.Update()
				// Best-effort refresh; a failure is simply retried on the next tick.
				_ = infoMap.UpdateSlowInfo()
			case <-ctx.Done():
				return
			}
		}
	}()
	return nil
}

func Shutdown() {
	Rocmlib.Shutdown()
	cancelFunc()
}

// DCUQuickInfo holds the frequently refreshed metrics of one DCU, written by
// DCUInfoMap.UpdateQuickInfo. A metric whose library query failed is left
// untouched: mo.None on first fill, or the previous value on later refreshes.
type DCUQuickInfo struct {
	// lock guards the fields below; UpdateQuickInfo holds it while writing.
	// NOTE(review): readers presumably should RLock it too — confirm callers do.
	lock             sync.RWMutex
	// Id is the device index.
	Id               int
	// Name is the device name from Rocmlib.GetDevName.
	Name             mo.Option[string]
	// PerformanceLevel comes from Rocmlib.GetPerfLevel.
	PerformanceLevel mo.Option[string]
	// Fan is the fan speed as a decimal string, or "N/A" when the device has no
	// reading or reports 0.
	Fan              mo.Option[string]
	// Temp is the Rocmlib.GetTemp reading divided by 1000 (presumably °C — TODO confirm).
	Temp             mo.Option[float32]
	// PwrAvg is Rocmlib.GetPowerAvg divided by 1e6 (presumably watts — TODO confirm).
	PwrAvg           mo.Option[float32]
	// PwrCap is Rocmlib.GetPowerCap divided by 1e6 (presumably watts — TODO confirm).
	PwrCap           mo.Option[float32]
	// BusId is the PCI bus id from Rocmlib.GetPCIBusId.
	BusId            mo.Option[string]
	// MemTotal is the total device memory from Rocmlib.GetMemTotal (unit not shown here).
	MemTotal         mo.Option[uint64]
	// MemUsed is the used device memory from Rocmlib.GetMemUsed (same unit as MemTotal).
	MemUsed          mo.Option[uint64]
	// MemUsedPerent is MemUsed/MemTotal*100, or 0 when MemTotal is 0.
	// (The name is a misspelling of "Percent", kept for compatibility.)
	MemUsedPerent    mo.Option[float32]
	// DCUUTil is the busy percentage from Rocmlib.GetBusyPercent.
	DCUUTil          mo.Option[float32]
}

// DCUSlowInfo holds slowly changing per-DCU state, written by
// DCUInfoMap.UpdateSlowInfo (run periodically by the loop Init starts).
// Every field is an atomic, so a single field can be read without a per-entry lock.
type DCUSlowInfo struct {
	Id      atomic.Int32 // device index
	Ecc     atomic.Bool  // ECC flag from gpu.GetEccInfo (presumably "ECC enabled"); false when unavailable
	PwrMode atomic.Value // string: mode from gpu.GetRunningInfo, or "Normal" when the device is absent there
	Mig     atomic.Bool  // UpdateSlowInfo always stores false
}

// DCUInfo bundles the slow and quick information of one DCU.
// It contains a sync.RWMutex and atomics by value, so it must not be copied after use.
type DCUInfo struct {
	SlowInfo  DCUSlowInfo
	QuickInfo DCUQuickInfo
}

// DCUInfoMap stores per-DCU information keyed by device index. Init allocates
// both maps. UpdateQuickInfo writes the frequently changing metrics and
// UpdateSlowInfo the slow-changing state. GetQuitInfo and GetSlowInfo return
// the maps together with an already-held read lock.
type DCUInfoMap struct {
	qinfolock sync.RWMutex // guards qinfo
	qinfo     map[int]*DCUQuickInfo
	sinfolock sync.RWMutex // guards sinfo
	sinfo     map[int]*DCUSlowInfo
}

// UpdateQuickInfo refreshes the frequently changing metrics of every DCU
// reported by Rocmlib.
//
// Each per-metric query is best-effort. When a query fails (returns nil), the
// corresponding field of every entry keeps its previous value. Entries for
// device indices that are no longer reported are removed. The only error
// returned is the one from Rocmlib.GetDevNumber.
func (m *DCUInfoMap) UpdateQuickInfo() error {
	num, err := Rocmlib.GetDevNumber()
	if err != nil {
		return err
	}
	// Query everything before taking the lock so readers are not blocked on
	// library calls. Errors are deliberately ignored (see above).
	names, _ := Rocmlib.GetDevName()
	plevel, _ := Rocmlib.GetPerfLevel()
	fan, _ := Rocmlib.GetFanSpeed()
	temp, _ := Rocmlib.GetTemp()
	pwrAvg, _ := Rocmlib.GetPowerAvg()
	pwrCap, _ := Rocmlib.GetPowerCap()
	busid, _ := Rocmlib.GetPCIBusId()
	memTotal, _ := Rocmlib.GetMemTotal()
	memUsed, _ := Rocmlib.GetMemUsed()
	dcu, _ := Rocmlib.GetBusyPercent()

	m.qinfolock.Lock()
	defer m.qinfolock.Unlock()

	// Make the zero value of DCUInfoMap usable: writing to a nil map panics.
	if m.qinfo == nil {
		m.qinfo = make(map[int]*DCUQuickInfo)
	}

	for i := range num {
		qinfo, ok := m.qinfo[i]
		if !ok {
			qinfo = &DCUQuickInfo{}
			m.qinfo[i] = qinfo
		}
		qinfo.lock.Lock()
		qinfo.Id = i
		if names != nil {
			qinfo.Name = mo.Some(names[i])
		}
		if plevel != nil {
			qinfo.PerformanceLevel = mo.Some(plevel[i])
		}
		// Report "N/A" when the device has no fan reading or reports 0.
		if fan != nil {
			if rpm, ok := fan[i]; !ok || rpm == 0 {
				qinfo.Fan = mo.Some("N/A")
			} else {
				qinfo.Fan = mo.Some(strconv.Itoa(int(rpm)))
			}
		}
		// Raw readings are scaled by 1000 (temperature) and 1e6 (power);
		// presumably milli-°C and µW — TODO confirm against the library.
		if temp != nil {
			qinfo.Temp = mo.Some(float32(temp[i]) / 1000)
		}
		if pwrAvg != nil {
			qinfo.PwrAvg = mo.Some(float32(pwrAvg[i]) / 1000000)
		}
		if pwrCap != nil {
			qinfo.PwrCap = mo.Some(float32(pwrCap[i]) / 1000000)
		}
		if busid != nil {
			qinfo.BusId = mo.Some(busid[i])
		}
		if memTotal != nil {
			qinfo.MemTotal = mo.Some(memTotal[i])
		}
		if memUsed != nil {
			qinfo.MemUsed = mo.Some(memUsed[i])
		}
		if qinfo.MemTotal.IsSome() {
			total := qinfo.MemTotal.MustGet()
			if total == 0 {
				// Avoid dividing by zero for devices that report no memory.
				qinfo.MemUsedPerent = mo.Some(float32(0.0))
			} else if qinfo.MemUsed.IsSome() {
				qinfo.MemUsedPerent = mo.Some(float32(qinfo.MemUsed.MustGet()) / float32(total) * 100)
			}
		}
		if dcu != nil {
			qinfo.DCUUTil = mo.Some(float32(dcu[i]))
		}
		qinfo.lock.Unlock()
	}

	// Drop entries for devices that are no longer reported (indices outside [0, num)).
	for k := range m.qinfo {
		if k < 0 || k >= num {
			delete(m.qinfo, k)
		}
	}
	return nil
}

// UpdateSlowInfo refreshes the slowly changing state (ECC, power mode, MIG) of
// every DCU reported by Rocmlib, and removes entries for devices that are no
// longer reported.
//
// ECC information is optional: if gpu.GetEccInfo fails, every device reports
// ECC as false instead of failing the refresh. Errors from
// Rocmlib.GetDevNumber and gpu.GetRunningInfo are returned, and in that case
// nothing is updated.
func (m *DCUInfoMap) UpdateSlowInfo() error {
	num, err := Rocmlib.GetDevNumber()
	if err != nil {
		return err
	}
	ecc, err := gpu.GetEccInfo()
	if err != nil {
		// Discard any partial result; lookups on a nil map yield false.
		ecc = nil
	}
	rinfo, err := gpu.GetRunningInfo()
	if err != nil {
		return err
	}

	m.sinfolock.Lock()
	defer m.sinfolock.Unlock()

	// Make the zero value of DCUInfoMap usable: writing to a nil map panics.
	if m.sinfo == nil {
		m.sinfo = make(map[int]*DCUSlowInfo)
	}

	for i := range num {
		sinfo, ok := m.sinfo[i]
		if !ok {
			sinfo = &DCUSlowInfo{}
			m.sinfo[i] = sinfo
		}
		sinfo.Id.Store(int32(i))
		sinfo.Mig.Store(false)
		if r, ok := rinfo[i]; ok {
			sinfo.PwrMode.Store(r.Mode)
		} else {
			sinfo.PwrMode.Store("Normal")
		}
		// A device missing from the ECC map reports false.
		sinfo.Ecc.Store(ecc[i])
	}

	// Drop entries for devices that are no longer reported (indices outside [0, num)).
	for k := range m.sinfo {
		if k < 0 || k >= num {
			delete(m.sinfo, k)
		}
	}
	return nil
}

// GetSlowInfo returns the slow-refresh info map together with a read lock that
// is already held. The caller must call Unlock on the returned Locker once it
// has finished reading the map.
func (m *DCUInfoMap) GetSlowInfo() (map[int]*DCUSlowInfo, sync.Locker) {
	m.sinfolock.RLock()
	return m.sinfo, m.sinfolock.RLocker()
}

// GetQuitInfo returns the quick-refresh info map together with a read lock
// that is already held. The caller must call Unlock on the returned Locker once
// it has finished reading the map. ("Quit" is a historical misspelling of
// "Quick", kept for compatibility.)
func (m *DCUInfoMap) GetQuitInfo() (map[int]*DCUQuickInfo, sync.Locker) {
	m.qinfolock.RLock()
	return m.qinfo, m.qinfolock.RLocker()
}

// DCUProcessInfo describes one process's use of one DCU.
type DCUProcessInfo struct {
	DCU    int              // index of the DCU in use
	DCUMem utils.MemorySize // DCU memory used by the process
	SDMA   int              // SDMA usage as reported by the device library
	Info   ProcessInfo      // generic (host-side) process information
}

// ProcessInfo holds generic, host-side information about a process.
type ProcessInfo struct {
	Ppid     int32   // parent process ID
	Pid      int32   // process ID
	User     string  // user name or uid
	CPU      float64 // CPU usage
	Mem      float32 // memory usage percentage
	Time     string  // CPU time consumed, formatted by DurationStr
	Cmd      string  // command line
	ContInfo *ContainerInfo // container the process runs in; nil if none is known
}

// getProcessInfo gathers ProcessInfo for each of pids, keyed by PID. Processes
// that cannot be opened (for example because they already exited) are omitted.
// Individual attributes are best-effort: a failed query leaves the zero value.
// Container information is attached from DockerPidInfo when available.
func getProcessInfo(pids []int32) map[int32]ProcessInfo {
	result := make(map[int32]ProcessInfo, len(pids))
	if len(pids) == 0 {
		return result
	}
	for _, pid := range pids {
		p, err := process.NewProcess(pid)
		if err != nil {
			continue
		}
		item := ProcessInfo{Pid: p.Pid}
		item.Ppid, _ = p.Ppid()
		item.User, _ = p.Username()
		item.CPU, _ = p.CPUPercent()
		item.Mem, _ = p.MemoryPercent()
		if t, err := p.Times(); err == nil {
			// Times are float seconds; scale before converting so the fraction is kept.
			cpuSeconds := t.System + t.User
			item.Time = DurationStr(time.Duration(cpuSeconds * float64(time.Second)))
		}
		item.Cmd, _ = p.Cmdline()
		result[p.Pid] = item
	}
	if len(result) == 0 {
		return result
	}

	// Take the container-map lock only for the lookups, not across the
	// per-process syscalls above, so the background refresher is not blocked.
	dockerInfo, lock := DockerPidInfo.Get()
	defer lock.Unlock()
	for pid, item := range result {
		if d, ok := dockerInfo[pid]; ok {
			item.ContInfo = d
			result[pid] = item
		}
	}
	return result
}

// GetDCUProcessInfo lists the processes using each DCU, as reported by
// Rocmlib.GetProcessInfo. The returned map is keyed by DCU index, and each
// slice is sorted by PID (as in GetDCUProcessInfo2). An empty map is returned
// if the process query fails.
func (m *DCUInfoMap) GetDCUProcessInfo() map[int][]DCUProcessInfo {
	result := make(map[int][]DCUProcessInfo)
	info, err := Rocmlib.GetProcessInfo()
	if err != nil {
		return result
	}
	pids := make([]int32, 0, len(info))
	for _, v := range info {
		pids = append(pids, int32(v.Pid))
	}
	pinfo := getProcessInfo(pids)
	for _, v := range info {
		pid := int32(v.Pid)
		procInfo, ok := pinfo[pid]
		if !ok {
			// The process could not be inspected (e.g. it exited meanwhile);
			// still report its PID instead of a zero value.
			procInfo = ProcessInfo{Pid: pid}
		}
		mem := utils.MemorySize{Unit: utils.Byte, Num: v.VarmUsage}
		for _, idx := range v.UsedGPUIndex {
			dcu := int(idx)
			result[dcu] = append(result[dcu], DCUProcessInfo{
				DCU:    dcu,
				DCUMem: mem,
				SDMA:   int(v.SdmaUsage),
				Info:   procInfo,
			})
		}
	}
	// Sort each DCU's processes by PID for a stable, deterministic listing.
	for _, procs := range result {
		sort.SliceStable(procs, func(i, j int) bool {
			return procs[i].Info.Pid < procs[j].Info.Pid
		})
	}
	return result
}

// GetDCUProcessInfo2 lists the processes using each DCU, as reported by
// gpu.GetDCUPidInfo. The returned map is keyed by DCU index, and each slice is
// sorted by PID. An empty map is returned if the query fails. Device indices
// that cannot be parsed as integers are skipped.
func (m *DCUInfoMap) GetDCUProcessInfo2() map[int][]DCUProcessInfo {
	result := make(map[int][]DCUProcessInfo)
	info, err := gpu.GetDCUPidInfo()
	if err != nil {
		return result
	}
	pids := make([]int32, 0, len(info))
	for _, v := range info {
		pids = append(pids, int32(v.Pid))
	}
	pinfo := getProcessInfo(pids)
	for _, v := range info { // one entry per process
		pid := int32(v.Pid)
		procInfo, ok := pinfo[pid]
		if !ok {
			// The process could not be inspected (e.g. it exited meanwhile);
			// still report its PID instead of a zero value.
			procInfo = ProcessInfo{Pid: pid}
		}
		// Add the process to every DCU it uses.
		for _, s := range v.HCUIndex {
			dcu, err := strconv.Atoi(s)
			if err != nil {
				continue
			}
			result[dcu] = append(result[dcu], DCUProcessInfo{
				DCU:    dcu,
				DCUMem: v.VRamUsed,
				SDMA:   v.SDMAUsed,
				Info:   procInfo,
			})
		}
	}
	// Sort each DCU's processes by PID (in place).
	for _, procs := range result {
		sort.SliceStable(procs, func(i, j int) bool {
			return procs[i].Info.Pid < procs[j].Info.Pid
		})
	}
	return result
}

// DurationStr formats d as hours:minutes:seconds ("H:MM:SS", with hours not
// padded) when it is shorter than 97 hours. Otherwise it formats d as
// fractional days with one decimal ("4.0 days"). A negative duration is
// formatted as its magnitude prefixed with "-".
func DurationStr(d time.Duration) string {
	var b strings.Builder
	if d < 0 {
		b.WriteByte('-')
		d = -d
		if d < 0 {
			// -math.MinInt64 overflows back to itself; clamp to the largest magnitude.
			d = math.MaxInt64
		}
	}
	if d >= 97*time.Hour {
		fmt.Fprintf(&b, "%.1f days", d.Hours()/24)
	} else {
		fmt.Fprintf(&b, "%d:%02d:%02d", int(d/time.Hour), int(d/time.Minute)%60, int(d/time.Second)%60)
	}
	return b.String()
}
