package lib

/*
#cgo CFLAGS: -I.
#cgo LDFLAGS: -L/opt/hyhal/lib -lamd_smi

#include <stdlib.h>
#include <string.h>
#include "rocm_smi.h"
*/
import "C"
import (
	"errors"
	"fmt"
	"unsafe"
)

// RSMIResult is the Go-side status code type. ToRSMIResult converts a
// C.rsmi_status_t value to this type in order to look up the matching Go
// error in RocmErrMap.
type RSMIResult uint32

// Status codes understood by this package.
//
// The values 0-19 and 0xFFFFFFFF are presumably kept in sync with the
// rsmi_status_t enum of rocm_smi.h — TODO confirm against the header in use.
// RSMIInitalizationError is an alias of RSMIInitError (same numeric value).
// RSMIMallocError pairs with ErrMallocError, which the Go wrappers in this
// file return when C.malloc yields NULL.
const (
	RSMISuccess            RSMIResult = 0
	RSMIInvalidArgs        RSMIResult = 1
	RSMINotSupported       RSMIResult = 2
	RSMIFileError          RSMIResult = 3
	RSMIPermission         RSMIResult = 4
	RSMIOutOfResource      RSMIResult = 5
	RSMIInternalException  RSMIResult = 6
	RSMIInputOutOfBounds   RSMIResult = 7
	RSMIInitError          RSMIResult = 8
	RSMIInitalizationError RSMIResult = RSMIInitError
	RSMINotYetImplemented  RSMIResult = 9
	RSMINotFound           RSMIResult = 10
	RSMIInsufficientSize   RSMIResult = 11
	RSMIInterrupt          RSMIResult = 12
	RSMIUnexpectedSize     RSMIResult = 13
	RSMINoData             RSMIResult = 14
	RSMIUnexpectedData     RSMIResult = 15
	RSMIBusy               RSMIResult = 16
	RSMIRefcountOverflow   RSMIResult = 17
	RSMISettingUnavaliable RSMIResult = 18
	RSMIAmdGPURestartErr   RSMIResult = 19

	RSMIMallocError RSMIResult = 0xFFFFFFFE // malloc failed
	RSMIUnknowError RSMIResult = 0xFFFFFFFF
)

var (
	// PerfNameMap maps a performance-level code to a display name.
	// Presumably the keys are the values returned by RSMI_dev_perf_level_get
	// — TODO confirm against rsmi_dev_perf_level_t in rocm_smi.h.
	// NOTE(review): the "unknow" spelling is kept as-is because callers may
	// match on these display strings.
	PerfNameMap = map[int32]string{
		0:   "auto",
		1:   "low",
		2:   "high",
		3:   "manual",
		4:   "stable std",
		5:   "stable peak",
		6:   "stable min mclk",
		7:   "stable min sclk",
		8:   "determinism",
		256: "unknow",
	}

	// Sentinel errors, one per status code. Compare with errors.Is (or ==).
	// Messages are lowercase with no trailing whitespace or punctuation.
	ErrInvalidArgs        = errors.New("rocm invalid args")
	ErrNotSupported       = errors.New("rocm not supported")
	ErrFileError          = errors.New("rocm file error")
	ErrPermission         = errors.New("rocm permission error")
	ErrOutOfResource      = errors.New("rocm out of resource")
	ErrInternalException  = errors.New("rocm internal exception")
	ErrInputOutOfBounds   = errors.New("rocm input out of bounds")
	ErrInitError          = errors.New("rocm init error")
	ErrInitalizationError = ErrInitError
	ErrNotYetImplemented  = errors.New("rocm not yet implemented")
	ErrNotFound           = errors.New("rocm not found")
	ErrInsufficientSize   = errors.New("rocm insufficient size")
	ErrInterrupt          = errors.New("rocm interrupt")
	ErrUnexpectedSize     = errors.New("rocm unexpected size")
	ErrNoData             = errors.New("rocm no data")
	ErrUnexpectedData     = errors.New("rocm unexpected data")
	ErrBusy               = errors.New("rocm busy")
	ErrRefcountOverflow   = errors.New("rocm refcount overflow")
	ErrSettingUnavaliable = errors.New("rocm setting unavailable")
	ErrAmdGPURestartErr   = errors.New("rocm amd gpu restart error")
	ErrMallocError        = errors.New("alloc memory error")
	ErrUnknowError        = errors.New("rocm unknown error")

	// RocmErrMap maps every known status code to its sentinel error.
	// RSMISuccess maps to nil so that a successful call yields a nil error.
	RocmErrMap = map[RSMIResult]error{
		RSMISuccess:            nil,
		RSMIInvalidArgs:        ErrInvalidArgs,
		RSMINotSupported:       ErrNotSupported,
		RSMIFileError:          ErrFileError,
		RSMIPermission:         ErrPermission,
		RSMIOutOfResource:      ErrOutOfResource,
		RSMIInternalException:  ErrInternalException,
		RSMIInputOutOfBounds:   ErrInputOutOfBounds,
		RSMIInitError:          ErrInitError,
		RSMINotYetImplemented:  ErrNotYetImplemented,
		RSMINotFound:           ErrNotFound,
		RSMIInsufficientSize:   ErrInsufficientSize,
		RSMIInterrupt:          ErrInterrupt,
		RSMIUnexpectedSize:     ErrUnexpectedSize,
		RSMINoData:             ErrNoData,
		RSMIUnexpectedData:     ErrUnexpectedData,
		RSMIBusy:               ErrBusy,
		RSMIRefcountOverflow:   ErrRefcountOverflow,
		RSMISettingUnavaliable: ErrSettingUnavaliable,
		RSMIAmdGPURestartErr:   ErrAmdGPURestartErr,
		RSMIMallocError:        ErrMallocError,
		RSMIUnknowError:        ErrUnknowError,
	}
)

// ToRSMIResult translates a C status code into the matching Go error.
//
// The code is looked up in RocmErrMap; a code that has no entry there yields
// ErrUnknowError. The entry for RSMISuccess is nil, so success yields nil.
func ToRSMIResult(c C.rsmi_status_t) error {
	if err, ok := RocmErrMap[RSMIResult(c)]; ok {
		return err
	}
	return ErrUnknowError
}

// RSMIProcessInfo is the Go counterpart of the C rsmi_process_info_t struct.
type RSMIProcessInfo struct {
	Pid                   uint32 // Process ID
	ProcessAddressSpaceId uint32 // PASID: (Process Address Space ID)
	VarmUsage             uint64 // VRAM usage
	SdmaUsage             uint64 // SDMA usage in microseconds
	CuOccupancy           uint32 // Compute Unit usage in percent
	// UsedGPUIndex is not filled in by FromC; callers have to populate it
	// themselves (presumably from RSMI_compute_process_gpus_get — verify
	// against callers).
	UsedGPUIndex          []uint32
}

// FromC copies the scalar fields of a C rsmi_process_info_t into pi.
// UsedGPUIndex is left untouched.
func (pi *RSMIProcessInfo) FromC(c C.rsmi_process_info_t) {
	// Identifiers.
	pi.Pid, pi.ProcessAddressSpaceId = uint32(c.process_id), uint32(c.pasid)
	// Resource usage counters.
	pi.VarmUsage, pi.SdmaUsage = uint64(c.vram_usage), uint64(c.sdma_usage)
	pi.CuOccupancy = uint32(c.cu_occupancy)
}

// RSMIProcessInfoV2 is the Go view of the C rsmi_process_info_v2_t struct.
// It is populated by its FromC method.
type RSMIProcessInfoV2 struct {
	Pid           uint32
	VramUsageSize uint64          // VRAM usage size in MiB
	VramUsageRate float32         // VRAM usage rate as a percentage
	UsedGPUs      int             // Used gpu number
	GPUUsage      map[int]float32 // GPU usage rate as a percentage
}

// FromC fills pi2 from a C rsmi_process_info_v2_t value.
//
// GPUUsage is rebuilt from scratch: for every valid slot k it maps the GPU
// index stored in c.gpuIndex[k] to the usage rate stored in c.gpuUsageRate[k].
//
// Only the first usedGpus slots are read, so unused slots of the fixed-size C
// arrays cannot add bogus entries or overwrite a real one. This assumes that
// usedGpus counts the valid leading entries of gpuIndex/gpuUsageRate — TODO
// confirm against rocm_smi.h. The count is clamped to the array lengths so an
// out-of-range value cannot cause an index panic.
func (pi2 *RSMIProcessInfoV2) FromC(c C.rsmi_process_info_v2_t) {
	pi2.Pid = uint32(c.processId)
	pi2.VramUsageSize = uint64(c.vramUsageSize)
	pi2.VramUsageRate = float32(c.vramUsageRate)
	pi2.UsedGPUs = int(c.usedGpus)

	valid := min(max(pi2.UsedGPUs, 0), len(c.gpuIndex), len(c.gpuUsageRate))
	pi2.GPUUsage = make(map[int]float32, valid)
	for k := 0; k < valid; k++ {
		pi2.GPUUsage[int(c.gpuIndex[k])] = float32(c.gpuUsageRate[k])
	}
}

// RSMI_init 初始化rsmi
func RSMI_init() error {
	return ToRSMIResult(C.rsmi_init(0))
}

// RSMI_shut_down 关闭rsmi
func RSMI_shut_down() error {
	return ToRSMIResult(C.rsmi_shut_down())
}

// RSMI_num_monitor_devices returns the number of devices reported by
// rsmi_num_monitor_devices, together with the translated status.
func RSMI_num_monitor_devices() (uint32, error) {
	var count C.uint
	status := C.rsmi_num_monitor_devices(&count)
	return uint32(count), ToRSMIResult(status)
}

// RSMI_version_get returns the major, minor and patch numbers of the RSMI
// library currently in use.
//
// The C result struct is allocated with C.malloc and released before the
// function returns. ErrMallocError is returned when the allocation fails.
func RSMI_version_get() (uint32, uint32, uint32, error) {
	ver := (*C.rsmi_version_t)(C.malloc(C.sizeof_rsmi_version_t))
	if ver == nil {
		return 0, 0, 0, ErrMallocError
	}
	defer C.free(unsafe.Pointer(ver))

	status := C.rsmi_version_get(ver)
	// Copy the fields out before the deferred free runs.
	major, minor, patch := uint32(ver.major), uint32(ver.minor), uint32(ver.patch)
	return major, minor, patch, ToRSMIResult(status)
}

// RSMI_version_str_get returns the version string of the driver component
// (RSMI_SW_COMP_DRIVER) of the current system.
//
// The library fills a fixed 128-byte buffer with what is expected to be a
// NUL-terminated C string. Only the bytes before the first NUL are returned,
// so the result carries no "\x00" padding. If no NUL is present the whole
// buffer is returned. The string is returned alongside the translated status.
func RSMI_version_str_get() (string, error) {
	const bufLen = 128
	buff := make([]byte, bufLen)
	res := C.rsmi_version_str_get(C.RSMI_SW_COMP_DRIVER, (*C.char)(unsafe.Pointer(&buff[0])), bufLen)
	// Find the end of the C string without reading past the buffer.
	n := 0
	for n < len(buff) && buff[n] != 0 {
		n++
	}
	return string(buff[:n]), ToRSMIResult(res)
}

// RSMI_dev_vbios_version_get returns the VBIOS version string of the device
// at deviceIndex.
//
// The library fills a fixed 128-byte buffer with what is expected to be a
// NUL-terminated C string. Only the bytes before the first NUL are returned,
// so the result carries no "\x00" padding. If no NUL is present the whole
// buffer is returned. The string is returned alongside the translated status.
func RSMI_dev_vbios_version_get(deviceIndex uint32) (string, error) {
	const bufLen = 128
	buff := make([]byte, bufLen)
	res := C.rsmi_dev_vbios_version_get(C.uint(deviceIndex), (*C.char)(unsafe.Pointer(&buff[0])), bufLen)
	// Find the end of the C string without reading past the buffer.
	n := 0
	for n < len(buff) && buff[n] != 0 {
		n++
	}
	return string(buff[:n]), ToRSMIResult(res)
}

// RSMI_dev_name_get returns the name of the device at deviceIndex. According
// to the original author the name consists of the leading letters only,
// without digits (for example "BW").
//
// The library fills a fixed 128-byte buffer with what is expected to be a
// NUL-terminated C string. Only the bytes before the first NUL are returned,
// so the result carries no "\x00" padding. If no NUL is present the whole
// buffer is returned. The string is returned alongside the translated status.
func RSMI_dev_name_get(deviceIndex uint32) (string, error) {
	const bufLen = 128
	buff := make([]byte, bufLen)
	res := C.rsmi_dev_name_get(C.uint(deviceIndex), (*C.char)(unsafe.Pointer(&buff[0])), bufLen)
	// Find the end of the C string without reading past the buffer.
	n := 0
	for n < len(buff) && buff[n] != 0 {
		n++
	}
	return string(buff[:n]), ToRSMIResult(res)
}

// RSMI_dev_id_get returns the device ID reported by rsmi_dev_id_get for the
// device at deviceIndex, together with the translated status.
func RSMI_dev_id_get(deviceIndex uint32) (uint16, error) {
	var devID C.ushort
	status := C.rsmi_dev_id_get(C.uint(deviceIndex), &devID)
	return uint16(devID), ToRSMIResult(status)
}

// RSMI_dev_sku_get returns the SKU number reported by rsmi_dev_sku_get for
// the device at deviceIndex, together with the translated status.
func RSMI_dev_sku_get(deviceIndex uint32) (uint16, error) {
	var skuID C.ushort
	status := C.rsmi_dev_sku_get(C.uint(deviceIndex), &skuID)
	return uint16(skuID), ToRSMIResult(status)
}

// RSMI_dev_vendor_id_get returns the vendor ID reported by
// rsmi_dev_vendor_id_get for the device at deviceIndex, together with the
// translated status.
func RSMI_dev_vendor_id_get(deviceIndex uint32) (uint16, error) {
	var vendorID C.ushort
	status := C.rsmi_dev_vendor_id_get(C.uint(deviceIndex), &vendorID)
	return uint16(vendorID), ToRSMIResult(status)
}

// RSMI_dev_brand_get returns the brand string reported by rsmi_dev_brand_get
// for the device at deviceIndex.
//
// The library fills a fixed 128-byte buffer with what is expected to be a
// NUL-terminated C string. Only the bytes before the first NUL are returned,
// so the result carries no "\x00" padding. If no NUL is present the whole
// buffer is returned. The string is returned alongside the translated status.
func RSMI_dev_brand_get(deviceIndex uint32) (string, error) {
	const bufLen = 128
	buff := make([]byte, bufLen)
	res := C.rsmi_dev_brand_get(C.uint(deviceIndex), (*C.char)(unsafe.Pointer(&buff[0])), bufLen)
	// Find the end of the C string without reading past the buffer.
	n := 0
	for n < len(buff) && buff[n] != 0 {
		n++
	}
	return string(buff[:n]), ToRSMIResult(res)
}

// RSMI_dev_serial_number_get returns the serial number string reported by
// rsmi_dev_serial_number_get for the device at deviceIndex.
//
// The library fills a fixed 128-byte buffer with what is expected to be a
// NUL-terminated C string. Only the bytes before the first NUL are returned,
// so the result carries no "\x00" padding. If no NUL is present the whole
// buffer is returned. The string is returned alongside the translated status.
func RSMI_dev_serial_number_get(deviceIndex uint32) (string, error) {
	const bufLen = 128
	buff := make([]byte, bufLen)
	res := C.rsmi_dev_serial_number_get(C.uint(deviceIndex), (*C.char)(unsafe.Pointer(&buff[0])), bufLen)
	// Find the end of the C string without reading past the buffer.
	n := 0
	for n < len(buff) && buff[n] != 0 {
		n++
	}
	return string(buff[:n]), ToRSMIResult(res)
}

// RSMI_dev_subsystem_name_get returns the subsystem name of the device at
// deviceIndex, which the original author describes as the full device name.
//
// The library fills a fixed 128-byte buffer with what is expected to be a
// NUL-terminated C string. Only the bytes before the first NUL are returned,
// so the result carries no "\x00" padding. If no NUL is present the whole
// buffer is returned. The string is returned alongside the translated status.
func RSMI_dev_subsystem_name_get(deviceIndex uint32) (string, error) {
	const bufLen = 128
	buff := make([]byte, bufLen)
	res := C.rsmi_dev_subsystem_name_get(C.uint(deviceIndex), (*C.char)(unsafe.Pointer(&buff[0])), bufLen)
	// Find the end of the C string without reading past the buffer.
	n := 0
	for n < len(buff) && buff[n] != 0 {
		n++
	}
	return string(buff[:n]), ToRSMIResult(res)
}

// RSMI_dev_perf_level_get returns the performance level reported by
// rsmi_dev_perf_level_get for the device at deviceIndex, converted to int32,
// together with the translated status.
func RSMI_dev_perf_level_get(deviceIndex uint32) (int32, error) {
	var lvl C.rsmi_dev_perf_level_t
	status := C.rsmi_dev_perf_level_get(C.uint(deviceIndex), &lvl)
	return int32(lvl), ToRSMIResult(status)
}

// RSMI_compute_process_info_get returns information about the processes that
// are using the GPUs.
//
// A C array with room for 128 entries is allocated with C.malloc and released
// before returning. The library is told the capacity through the count
// argument and reports back how many entries it produced. ErrMallocError is
// returned when the allocation fails. On any non-success status the slice is
// nil. On success the slice is non-nil, possibly empty.
func RSMI_compute_process_info_get() ([]RSMIProcessInfo, error) {
	procs := (*C.rsmi_process_info_t)(C.malloc(C.sizeof_rsmi_process_info_t * 128))
	if procs == nil {
		return nil, ErrMallocError
	}
	defer C.free(unsafe.Pointer(procs))

	count := C.uint(128)
	status := C.rsmi_compute_process_info_get(procs, &count)
	if status != C.RSMI_STATUS_SUCCESS {
		return nil, ToRSMIResult(status)
	}

	// View the C array as a Go slice, then copy each entry into Go memory
	// so nothing refers to the buffer after it has been freed.
	entries := unsafe.Slice(procs, int(count))
	infos := make([]RSMIProcessInfo, len(entries))
	for i, entry := range entries {
		infos[i].FromC(entry)
	}
	return infos, ToRSMIResult(status)
}

// RSMI_compute_process_info_by_pid_get_v2 returns detailed information about
// the process identified by pid.
//
// Note from the original author: not every version of the shared library
// supports this call, and using it may crash the process.
//
// The C result struct is allocated with C.malloc and released before
// returning. ErrMallocError is returned when the allocation fails. When the
// library reports a non-success status, info is nil and res is the translated
// error. The deferred recover turns a Go panic raised while the call is in
// progress into ErrUnknowError. It cannot intercept a fault raised inside the
// C library itself, because that is not a Go panic.
func RSMI_compute_process_info_by_pid_get_v2(pid uint32) (info *RSMIProcessInfoV2, res error) {
	ps2 := (*C.rsmi_process_info_v2_t)(C.malloc(C.sizeof_rsmi_process_info_v2_t))
	if ps2 == nil {
		return nil, ErrMallocError
	}
	defer C.free(unsafe.Pointer(ps2))
	// Registered after the free so that it runs first and can still
	// overwrite the named results.
	defer func() {
		if r := recover(); r != nil {
			info = nil
			res = ErrUnknowError
		}
	}()

	// Check the status returned by the library, not the named result
	// (which is still nil at this point).
	if err := ToRSMIResult(C.rsmi_compute_process_info_by_pid_get_v2(C.uint(pid), ps2)); err != nil {
		return nil, err
	}
	result := &RSMIProcessInfoV2{}
	result.FromC(*ps2)
	return result, nil
}

// RSMI_dev_fan_rpms_get returns the fan speed reported by
// rsmi_dev_fan_rpms_get for the device at devIndex. The second argument to
// the C call is fixed at 0 (presumably the sensor index — TODO confirm).
// On failure the value is 0 and the translated error is returned.
func RSMI_dev_fan_rpms_get(devIndex uint32) (int64, error) {
	var speed C.long
	if err := ToRSMIResult(C.rsmi_dev_fan_rpms_get(C.uint(devIndex), 0, &speed)); err != nil {
		return 0, err
	}
	return int64(speed), nil
}

// RSMI_dev_temp_metric_get returns the current core temperature of the device
// at devIndex (RSMI_TEMP_TYPE_CORE, RSMI_TEMP_CURRENT). According to the
// original author, dividing the result by 1000 gives degrees Celsius.
func RSMI_dev_temp_metric_get(devIndex uint32) (int64, error) {
	var reading C.long
	status := C.rsmi_dev_temp_metric_get(C.uint(devIndex), C.RSMI_TEMP_TYPE_CORE, C.RSMI_TEMP_CURRENT, &reading)
	return int64(reading), ToRSMIResult(status)
}

// RSMI_dev_power_ave_get returns the average power consumption of the device
// at devIndex. According to the original author the unit is microwatts.
// The second argument to the C call is fixed at 0 (presumably the sensor
// index — TODO confirm).
func RSMI_dev_power_ave_get(devIndex uint32) (uint64, error) {
	var avg C.ulong
	status := C.rsmi_dev_power_ave_get(C.uint(devIndex), 0, &avg)
	return uint64(avg), ToRSMIResult(status)
}

// RSMI_dev_power_cap_get returns the power cap of the device at devIndex.
// According to the original author the unit is microwatts. The second
// argument to the C call is fixed at 0 (presumably the sensor index — TODO
// confirm).
func RSMI_dev_power_cap_get(devIndex uint32) (uint64, error) {
	// Named limit rather than cap to avoid shadowing the builtin.
	var limit C.ulong
	status := C.rsmi_dev_power_cap_get(C.uint(devIndex), 0, &limit)
	return uint64(limit), ToRSMIResult(status)
}

// RSMI_dev_pci_id_get returns the PCI id of the device at devIndex. According
// to the original author, for a device at 0000:49:00.0 the returned value
// corresponds to the "49:00" part.
func RSMI_dev_pci_id_get(devIndex uint32) (uint64, error) {
	var bdf C.ulong
	status := C.rsmi_dev_pci_id_get(C.uint(devIndex), &bdf)
	return uint64(bdf), ToRSMIResult(status)
}

// RSMI_dev_memory_total_get returns the total amount of memory of type
// RSMI_MEM_TYPE_VRAM reported by rsmi_dev_memory_total_get for the device at
// devIndex, together with the translated status.
func RSMI_dev_memory_total_get(devIndex uint32) (uint64, error) {
	var total C.ulong
	status := C.rsmi_dev_memory_total_get(C.uint(devIndex), C.RSMI_MEM_TYPE_VRAM, &total)
	return uint64(total), ToRSMIResult(status)
}

// RSMI_dev_memory_usage_get returns the amount of memory of type
// RSMI_MEM_TYPE_VRAM in use, as reported by rsmi_dev_memory_usage_get for the
// device at devIndex, together with the translated status.
func RSMI_dev_memory_usage_get(devIndex uint32) (uint64, error) {
	var used C.ulong
	status := C.rsmi_dev_memory_usage_get(C.uint(devIndex), C.RSMI_MEM_TYPE_VRAM, &used)
	return uint64(used), ToRSMIResult(status)
}

// nvmlDeviceGetMigMode
// dmi/dmi_mig.h

// RSMI_dev_busy_percent_get returns the busy percentage reported by
// rsmi_dev_busy_percent_get for the device at devIndex, together with the
// translated status.
func RSMI_dev_busy_percent_get(devIndex uint32) (uint32, error) {
	var busy C.uint
	status := C.rsmi_dev_busy_percent_get(C.uint(devIndex), &busy)
	return uint32(busy), ToRSMIResult(status)
}

// RSMI_ecc_enable reports whether rsmi_dev_ecc_status_get returns
// RSMI_RAS_ERR_STATE_ENABLED for the device at devIndex.
//
// It first reads the value produced by rsmi_dev_ecc_enabled_get and returns
// (false, err) if that call fails. The value is then printed in hexadecimal
// to standard output and passed, converted to rsmi_gpu_block_t, as the block
// argument of rsmi_dev_ecc_status_get.
//
// NOTE(review): the Printf looks like leftover debug output in library code.
// NOTE(review): rsmi_dev_ecc_enabled_get presumably yields a bitmask of
// blocks while rsmi_dev_ecc_status_get presumably expects a single block —
// confirm against the rocm_smi documentation before relying on the result.
func RSMI_ecc_enable(devIndex uint32) (bool, error) {
	var enabledMask C.ulong
	status := C.rsmi_dev_ecc_enabled_get(C.uint(devIndex), &enabledMask)
	if err := ToRSMIResult(status); err != nil {
		return false, err
	}
	fmt.Printf("%X\n", enabledMask)

	var state C.rsmi_ras_err_state_t
	status = C.rsmi_dev_ecc_status_get(C.uint(devIndex), C.rsmi_gpu_block_t(enabledMask), &state)
	enabled := uint32(state) == uint32(C.RSMI_RAS_ERR_STATE_ENABLED)
	return enabled, ToRSMIResult(status)
}

// RSMI_compute_process_gpus_get returns the device indices reported by
// rsmi_compute_process_gpus_get for the process identified by pid.
//
// The library is given a 32-entry array and told its capacity through the
// count argument; it reports back how many entries it produced. On failure
// the slice is nil. On success it is non-nil, possibly empty.
func RSMI_compute_process_gpus_get(pid uint32) ([]uint32, error) {
	var ids [32]C.uint
	count := C.uint(32)
	status := C.rsmi_compute_process_gpus_get(C.uint(pid), &ids[0], &count)
	if err := ToRSMIResult(status); err != nil {
		return nil, err
	}
	gpus := make([]uint32, count)
	for i := range gpus {
		gpus[i] = uint32(ids[i])
	}
	return gpus, nil
}
