package lib

import (
	"errors"
	"slices"
	"sync"
	"sync/atomic"

	"get-container/utils"
)

var (
	// rocmlib_flag is set to 1 once GetRocmlib has created Rocmlib_instance.
	rocmlib_flag atomic.Int32
	// Rocmlib_instance is the package-wide rocmlib singleton; it stays nil
	// until the first call to GetRocmlib.
	Rocmlib_instance *rocmlib
	// rocmlibOnce guarantees Rocmlib_instance is constructed exactly once,
	// even when GetRocmlib is called from several goroutines at the same time.
	rocmlibOnce sync.Once

	// ErrNotGetDevNum is returned by the per-device getters until
	// GetDevNumber has cached the device count.
	ErrNotGetDevNum = errors.New("not get dev num yet")
	// ErrNotInit is returned by the query methods while the library is not
	// marked as initialized (see Init).
	ErrNotInit = errors.New("not init rocm lib yet")
)

// rocmlib is the handle to the ROCm SMI library; obtain it via GetRocmlib.
// It contains atomics and must not be copied.
type rocmlib struct {
	// status is the library state: 0 = not initialized, 1 = initialized.
	status atomic.Int32
	// carNum is the card count cached by GetDevNumber. The initial value -1
	// means the count has not been fetched yet, so GetDevNumber must be called.
	// NOTE(review): this is a plain int read and written without
	// synchronization; concurrent GetDevNumber/Shutdown and getters race on it.
	carNum int
}

// GetRocmlib returns the package-wide rocmlib singleton, creating it on the
// first call. A fresh instance is uninitialized (call Init) and has no cached
// device count (call GetDevNumber).
//
// Construction goes through sync.Once. The previous check-then-store on
// rocmlib_flag let concurrent first callers each build and return a different
// instance.
func GetRocmlib() *rocmlib {
	rocmlibOnce.Do(func() {
		Rocmlib_instance = &rocmlib{carNum: -1}
		// Kept for any code that still inspects the flag.
		rocmlib_flag.Store(1)
	})
	return Rocmlib_instance
}

// Init initializes the ROCm SMI library via RSMI_init if this handle is not
// already marked as initialized.
//
// It returns (true, nil) on success, or when status was already 1. If
// RSMI_init fails, status is rolled back to 0 and (false, err) is returned.
//
// NOTE(review): status is flipped to 1 before RSMI_init returns. During that
// window, IsInited and the getters already treat the library as ready, and a
// concurrent Init returns (true, nil) even if the first call later fails.
func (r *rocmlib) Init() (bool, error) {
	if !r.status.CompareAndSwap(0, 1) {
		return true, nil
	}
	if err := RSMI_init(); err != nil {
		r.status.Store(0)
		return false, err
	}
	return true, nil
}

// Shutdown shuts the ROCm SMI library down via RSMI_shut_down if this handle
// is marked as initialized.
//
// It returns (true, nil) on success, or when the handle was not initialized.
// A successful shutdown also resets the cached device count to -1, so
// GetDevNumber must be called again after a later Init. If RSMI_shut_down
// fails, status is restored to 1 and (false, err) is returned.
//
// NOTE(review): status is cleared before RSMI_shut_down runs. A getter that
// passed its status check just before may still be issuing RSMI calls.
func (r *rocmlib) Shutdown() (bool, error) {
	if !r.status.CompareAndSwap(1, 0) {
		return true, nil
	}
	if err := RSMI_shut_down(); err != nil {
		r.status.Store(1)
		return false, err
	}
	r.carNum = -1
	return true, nil
}

// IsInited reports whether the handle's status is 1, i.e. Init has marked the
// library as initialized and no Shutdown has cleared it since.
func (r *rocmlib) IsInited() bool {
	const initialized int32 = 1
	return initialized == r.status.Load()
}

// GetDevNumber asks RSMI_num_monitor_devices for the number of devices,
// caches it in carNum for the per-device getters, and returns it.
//
// It returns ErrNotInit when the library is not initialized. If the query
// fails, it returns the error and leaves the cached count unchanged.
func (r *rocmlib) GetDevNumber() (int, error) {
	if r.status.Load() != 1 {
		return 0, ErrNotInit
	}
	count, err := RSMI_num_monitor_devices()
	if err != nil {
		return 0, err
	}
	r.carNum = int(count)
	return r.carNum, nil
}

// GetDevName returns the subsystem name of every device, keyed by device
// index, as reported by RSMI_dev_subsystem_name_get. A device whose name
// cannot be read is reported as "unknow".
// NOTE(review): "unknow" looks like a typo for "unknown"; it is kept as-is
// because callers may compare against it.
//
// It returns ErrNotInit when the library is not initialized, and
// ErrNotGetDevNum until GetDevNumber has been called.
func (r *rocmlib) GetDevName() (map[int]string, error) {
	if r.status.Load() != 1 {
		return nil, ErrNotInit
	}
	devices := r.carNum
	if devices == -1 {
		return nil, ErrNotGetDevNum
	}
	names := make(map[int]string)
	for dev := 0; dev < devices; dev++ {
		name, err := RSMI_dev_subsystem_name_get(uint32(dev))
		if err != nil {
			name = "unknow"
		}
		names[dev] = name
	}
	return names, nil
}

// GetPerfLevel returns the performance-level name of every device, keyed by
// device index. The raw level from RSMI_dev_perf_level_get is translated
// through PerfNameMap. A device whose level cannot be read, or whose level
// has no entry in PerfNameMap, is reported as "unknow".
//
// It returns ErrNotInit when the library is not initialized, and
// ErrNotGetDevNum until GetDevNumber has been called.
func (r *rocmlib) GetPerfLevel() (map[int]string, error) {
	if r.status.Load() != 1 {
		return nil, ErrNotInit
	}
	devices := r.carNum
	if devices == -1 {
		return nil, ErrNotGetDevNum
	}
	levels := make(map[int]string)
	for dev := range devices {
		label := "unknow"
		if level, err := RSMI_dev_perf_level_get(uint32(dev)); err == nil {
			if mapped, ok := PerfNameMap[level]; ok {
				label = mapped
			}
		}
		levels[dev] = label
	}
	return levels, nil
}

// GetFanSpeed returns the fan speed of every device, keyed by device index,
// as reported by RSMI_dev_fan_rpms_get. A device whose speed cannot be read
// is reported as 0.
//
// It returns ErrNotInit when the library is not initialized, and
// ErrNotGetDevNum until GetDevNumber has been called.
func (r *rocmlib) GetFanSpeed() (map[int]int64, error) {
	devices := r.carNum
	switch {
	case r.status.Load() != 1:
		return nil, ErrNotInit
	case devices == -1:
		return nil, ErrNotGetDevNum
	}
	speeds := make(map[int]int64)
	for dev := range devices {
		rpm, err := RSMI_dev_fan_rpms_get(uint32(dev))
		if err != nil {
			rpm = 0
		}
		speeds[dev] = rpm
	}
	return speeds, nil
}

// GetTemp returns the temperature of every device, keyed by device index, as
// reported by RSMI_dev_temp_metric_get. A device whose reading fails is
// reported as 0.
// NOTE(review): the unit and the sensor/metric are chosen inside the
// RSMI_dev_temp_metric_get wrapper — confirm there.
//
// It returns ErrNotInit when the library is not initialized, and
// ErrNotGetDevNum until GetDevNumber has been called.
func (r *rocmlib) GetTemp() (map[int]int64, error) {
	devices := r.carNum
	switch {
	case r.status.Load() != 1:
		return nil, ErrNotInit
	case devices == -1:
		return nil, ErrNotGetDevNum
	}
	temps := make(map[int]int64)
	for dev := 0; dev < devices; dev++ {
		reading, err := RSMI_dev_temp_metric_get(uint32(dev))
		if err != nil {
			reading = 0
		}
		temps[dev] = reading
	}
	return temps, nil
}

// GetPowerAvg returns the average power of every card, keyed by device index,
// as reported by RSMI_dev_power_ave_get. The original author documents the
// unit as milliwatts. A card whose value cannot be read is reported as 0.
// NOTE(review): the underlying ROCm SMI C call rsmi_dev_power_ave_get reports
// microwatts — confirm whether the RSMI_dev_power_ave_get wrapper converts.
//
// It returns ErrNotInit when the library is not initialized, and
// ErrNotGetDevNum until GetDevNumber has been called.
func (r *rocmlib) GetPowerAvg() (map[int]uint64, error) {
	if r.status.Load() != 1 {
		return nil, ErrNotInit
	}
	devices := r.carNum
	if devices == -1 {
		return nil, ErrNotGetDevNum
	}
	avg := make(map[int]uint64)
	for dev := range devices {
		pwr, err := RSMI_dev_power_ave_get(uint32(dev))
		if err != nil {
			pwr = 0
		}
		avg[dev] = pwr
	}
	return avg, nil
}

// GetPowerCap returns the power cap of every card, keyed by device index, as
// reported by RSMI_dev_power_cap_get. The original author documents the unit
// as milliwatts. A card whose cap cannot be read is reported as 0.
// NOTE(review): the underlying ROCm SMI C call rsmi_dev_power_cap_get reports
// microwatts — confirm whether the RSMI_dev_power_cap_get wrapper converts.
//
// It returns ErrNotInit when the library is not initialized, and
// ErrNotGetDevNum until GetDevNumber has been called.
func (r *rocmlib) GetPowerCap() (map[int]uint64, error) {
	if r.status.Load() != 1 {
		return nil, ErrNotInit
	}
	devices := r.carNum
	if devices == -1 {
		return nil, ErrNotGetDevNum
	}
	caps := make(map[int]uint64)
	for dev := range devices {
		pwr, err := RSMI_dev_power_cap_get(uint32(dev))
		if err != nil {
			// Best-effort: one unreadable card must not hide the others.
			pwr = 0
		}
		caps[dev] = pwr
	}
	return caps, nil
}

// GetPCIBusId returns the PCI bus identifier string of every device, keyed by
// device index. Each value comes from RSMI_dev_pci_id_get, formatted with
// utils.PCIBus(pci, 0). A device whose ID cannot be read is reported as
// "unknow".
// NOTE(review): the meaning of the second argument (0) is defined by
// utils.PCIBus — confirm there.
//
// It returns ErrNotInit when the library is not initialized, and
// ErrNotGetDevNum until GetDevNumber has been called.
func (r *rocmlib) GetPCIBusId() (map[int]string, error) {
	if r.status.Load() != 1 {
		return nil, ErrNotInit
	}
	devices := r.carNum
	if devices == -1 {
		return nil, ErrNotGetDevNum
	}
	busIDs := make(map[int]string)
	for dev := range devices {
		pci, err := RSMI_dev_pci_id_get(uint32(dev))
		if err != nil {
			busIDs[dev] = "unknow"
			continue
		}
		busIDs[dev] = utils.PCIBus(pci, 0)
	}
	return busIDs, nil
}

// GetMemTotal returns the total memory of every device, keyed by device index,
// as reported by RSMI_dev_memory_total_get. A device whose value cannot be
// read is reported as 0.
//
// It returns ErrNotInit when the library is not initialized, and
// ErrNotGetDevNum until GetDevNumber has been called.
func (r *rocmlib) GetMemTotal() (map[int]uint64, error) {
	devices := r.carNum
	switch {
	case r.status.Load() != 1:
		return nil, ErrNotInit
	case devices == -1:
		return nil, ErrNotGetDevNum
	}
	totals := make(map[int]uint64)
	for dev := range devices {
		total, err := RSMI_dev_memory_total_get(uint32(dev))
		if err != nil {
			total = 0
		}
		totals[dev] = total
	}
	return totals, nil
}

// GetMemUsed returns the used memory of every device, keyed by device index,
// as reported by RSMI_dev_memory_usage_get. A device whose value cannot be
// read is reported as 0.
//
// It returns ErrNotInit when the library is not initialized, and
// ErrNotGetDevNum until GetDevNumber has been called.
func (r *rocmlib) GetMemUsed() (map[int]uint64, error) {
	if r.status.Load() != 1 {
		return nil, ErrNotInit
	}
	devices := r.carNum
	if devices == -1 {
		return nil, ErrNotGetDevNum
	}
	used := make(map[int]uint64)
	for dev := 0; dev < devices; dev++ {
		if amount, err := RSMI_dev_memory_usage_get(uint32(dev)); err == nil {
			used[dev] = amount
		} else {
			used[dev] = 0
		}
	}
	return used, nil
}

// GetBusyPercent returns the busy percentage of every device, keyed by device
// index, as reported by RSMI_dev_busy_percent_get. A device whose value
// cannot be read is reported as 0.
//
// It returns ErrNotInit when the library is not initialized, and
// ErrNotGetDevNum until GetDevNumber has been called.
func (r *rocmlib) GetBusyPercent() (map[int]uint32, error) {
	if r.status.Load() != 1 {
		return nil, ErrNotInit
	}
	devices := r.carNum
	if devices == -1 {
		return nil, ErrNotGetDevNum
	}
	busy := make(map[int]uint32)
	for dev := range devices {
		pct, err := RSMI_dev_busy_percent_get(uint32(dev))
		if err != nil {
			pct = 0
		}
		busy[dev] = pct
	}
	return busy, nil
}

// GetSystemDriverVersion returns the version string produced by
// RSMI_version_str_get. It returns ErrNotInit when the library is not
// initialized. Unlike the per-device getters, it does not require
// GetDevNumber to have been called.
func (r *rocmlib) GetSystemDriverVersion() (string, error) {
	if r.status.Load() == 1 {
		return RSMI_version_str_get()
	}
	return "", ErrNotInit
}

// GetProcessInfo returns the process list from RSMI_compute_process_info_get.
// Each entry's UsedGPUIndex is filled from RSMI_compute_process_gpus_get, and
// entries with no GPU indices are dropped. The filtering reuses the returned
// slice in place.
//
// It returns ErrNotInit when the library is not initialized, ErrNotGetDevNum
// until GetDevNumber has been called, and an empty slice when the cached
// device count is 0. An error from the list call is returned together with
// whatever slice that call produced. An error from any per-process GPU lookup
// aborts the whole call with a nil slice.
func (r *rocmlib) GetProcessInfo() ([]RSMIProcessInfo, error) {
	devices := r.carNum
	if r.status.Load() != 1 {
		return nil, ErrNotInit
	}
	switch devices {
	case -1:
		return nil, ErrNotGetDevNum
	case 0:
		return make([]RSMIProcessInfo, 0), nil
	}
	procs, err := RSMI_compute_process_info_get()
	if err != nil {
		return procs, err
	}
	if len(procs) == 0 {
		return procs, nil
	}
	for idx := range procs {
		gpus, gpuErr := RSMI_compute_process_gpus_get(procs[idx].Pid)
		if gpuErr != nil {
			return nil, gpuErr
		}
		procs[idx].UsedGPUIndex = gpus
	}
	return slices.DeleteFunc(procs, func(p RSMIProcessInfo) bool {
		return len(p.UsedGPUIndex) == 0
	}), nil
}
