package gpu

import (
	"errors"
	"get-container/utils"
	"os/exec"
	"strconv"
	"strings"
)

/*
	Helpers that collect information by running the nvidia-smi command.
*/

// Labels (the text before the ':') of the lines printed by
// `nvidia-smi --version`. GetVersionInfo compares them case-insensitively to
// decide which NVVersionInfo field a line fills.
const (
	SmiVersionHeader    = "NVIDIA-SMI version" // stored in NVVersionInfo.SMIVersion
	NvmlVersionHeader   = "NVML version"       // stored in NVVersionInfo.NVMLVersion
	DriverVersionHeader = "DRIVER version"     // stored in NVVersionInfo.DriverVersion
	CudaVersionHeader   = "CUDA Version"       // stored in NVVersionInfo.CUDAVersion
)

// NVAppInfo describes one GPU compute process, holding the per-process
// information that can be read directly from
// `nvidia-smi --query-compute-apps`. The trailing comment on each field names
// the queried column GetAppInfo fills it from.
type NVAppInfo struct {
	GPUName       string           // gpu_name
	GPUBusId      string           // gpu_bus_id
	GPUSerial     string           // gpu_serial
	GPUUUID       string           // gpu_uuid
	Pid           uint64           // pid
	ProcessName   string           // process_name
	UsedGPUMemory utils.MemorySize // used_gpu_memory, parsed by utils.ParseMemorySize
}

// GetAppInfo lists the compute processes currently reported by
//
//	nvidia-smi --query-compute-apps=gpu_name,gpu_bus_id,gpu_serial,gpu_uuid,pid,process_name,used_gpu_memory --format=csv,noheader
//
// The command prints one comma-separated line per process, for example:
//
//	NVIDIA H20, 00000000:0F:00.0, 1321424020484, GPU-f71f52ad-4c29-30dd-7f0f-609de5ff1510, 1272015, /usr/bin/python3, 89976 MiB
//	NVIDIA H20, 00000000:34:00.0, 1321424019230, GPU-e6c3552d-98b5-fd23-a8e0-c1d85fccfdaa, 1272016, /usr/bin/python3, 90072 MiB
//
// Lines with fewer than 7 fields are skipped. When no process is listed the
// result is an empty, non-nil slice (the same convention GetGPUInfo uses).
// An error is returned when the command fails, or when the pid or the used
// memory of a line cannot be parsed.
func GetAppInfo() ([]NVAppInfo, error) {
	output, err := exec.Command("nvidia-smi",
		"--query-compute-apps=gpu_name,gpu_bus_id,gpu_serial,gpu_uuid,pid,process_name,used_gpu_memory",
		"--format=csv,noheader").Output()
	if err != nil {
		return nil, err
	}

	result := make([]NVAppInfo, 0)
	outStr := strings.Trim(string(output), "\n")
	if len(outStr) == 0 {
		return result, nil
	}

	// strings.Split never returns an empty slice for a non-empty separator,
	// so no further emptiness check is needed before iterating.
	for _, line := range strings.Split(outStr, "\n") {
		fields := strings.Split(strings.TrimSpace(line), ",")
		if len(fields) < 7 {
			continue
		}

		pid, pidErr := strconv.ParseUint(strings.TrimSpace(fields[4]), 10, 64)
		if pidErr != nil {
			return nil, pidErr
		}

		// Columns are separated by ", ", so every column after the first
		// starts with a space; trim it before parsing, exactly like the
		// other columns and like GetGPUInfo does for its memory columns.
		mem, memErr := utils.ParseMemorySize(strings.TrimSpace(fields[6]))
		if memErr != nil {
			return nil, memErr
		}
		if mem == nil {
			return nil, errors.New("parse storage size error")
		}

		result = append(result, NVAppInfo{
			GPUName:       strings.TrimSpace(fields[0]),
			GPUBusId:      strings.TrimSpace(fields[1]),
			GPUSerial:     strings.TrimSpace(fields[2]),
			GPUUUID:       strings.TrimSpace(fields[3]),
			Pid:           pid,
			ProcessName:   strings.TrimSpace(fields[5]),
			UsedGPUMemory: *mem,
		})
	}
	return result, nil
}

// NVVersionInfo holds the version strings printed by `nvidia-smi --version`.
// Each field is the trimmed text after the ':' on the line whose label matches
// the corresponding *Header constant; a field stays empty when no such line
// was printed.
type NVVersionInfo struct {
	SMIVersion    string // value of the SmiVersionHeader line
	NVMLVersion   string // value of the NvmlVersionHeader line
	DriverVersion string // value of the DriverVersionHeader line
	CUDAVersion   string // value of the CudaVersionHeader line
}

// GetVersionInfo runs `nvidia-smi --version` and collects the version strings
// it prints.
//
// Every non-blank output line must have the form "<label> : <value>". The
// label is compared case-insensitively against SmiVersionHeader,
// NvmlVersionHeader, DriverVersionHeader and CudaVersionHeader; lines with any
// other label are ignored.
//
// An error is returned when the command fails, when it prints nothing, or when
// a non-blank line contains no ':'.
func GetVersionInfo() (*NVVersionInfo, error) {
	output, err := exec.Command("nvidia-smi", "--version").Output()
	if err != nil {
		return nil, err
	}
	outStr := strings.TrimSpace(string(output))
	if len(outStr) == 0 {
		return nil, errors.New("nvidia-smi version not found")
	}

	result := &NVVersionInfo{}
	for _, line := range strings.Split(outStr, "\n") {
		line = strings.TrimSpace(line)
		if line == "" {
			// Blank lines (including the empty element that a trailing
			// newline would otherwise produce) carry no information and
			// must not be reported as a parse failure.
			continue
		}
		field := strings.SplitN(line, ":", 2)
		if len(field) != 2 {
			return nil, errors.New("parse nvidia-smi version error")
		}
		value := strings.TrimSpace(field[1])
		switch strings.ToLower(strings.TrimSpace(field[0])) {
		case strings.ToLower(SmiVersionHeader):
			result.SMIVersion = value
		case strings.ToLower(NvmlVersionHeader):
			result.NVMLVersion = value
		case strings.ToLower(DriverVersionHeader):
			result.DriverVersion = value
		case strings.ToLower(CudaVersionHeader):
			result.CUDAVersion = value
		}
	}
	return result, nil
}

// Info holds the basic properties of one GPU as returned by GetGPUInfo. The
// trailing comment on each field names the `nvidia-smi --query-gpu` column it
// is filled from; string fields keep the column text with surrounding
// whitespace trimmed.
type Info struct {
	GPUName          string           // name
	DriverVersion    string           // driver_version
	PersistenceMode  bool             // persistence_mode; true when the column reads "Enabled"
	FanSpeed         string           // fan.speed
	Temperature      string           // temperature.gpu
	PerformanceState string           // pstate
	BusID            string           // pci.bus_id
	DisplayActive    bool             // display_active; true when the column reads "Enabled"
	PowerUsage       string           // power.draw
	PowerCapacity    string           // power.limit
	MemorySize       utils.MemorySize // memory.total, parsed by utils.ParseMemorySize
	MemoryUsage      utils.MemorySize // memory.used, parsed by utils.ParseMemorySize
	VBIOSVersion     string           // vbios_version
	MIGMode          bool             // mig.mode.current; true when the column reads "Enabled"
}

// GetGPUInfo queries nvidia-smi for the per-GPU columns named in the
// --query-gpu argument below and returns one Info per well-formed output line.
//
// Example output line:
//
//	NVIDIA H20, 570.86.10, Enabled, [N/A], 34, P0, 00000000:0F:00.0, Disabled, 123.33 W, 500.00 W, 97871 MiB, 89986 MiB, 96.00.99.00.1D, Disabled
//
// Lines that do not split into exactly 14 comma-separated columns are ignored.
// Empty output yields an empty, non-nil slice. An error is returned when the
// command fails or when memory.total / memory.used cannot be parsed.
func GetGPUInfo() ([]Info, error) {
	const (
		columnCount = 14
		enabled     = "Enabled"
		query       = "--query-gpu=name,driver_version,persistence_mode,fan.speed,temperature.gpu,pstate,pci.bus_id,display_active,power.draw,power.limit,memory.total,memory.used,vbios_version,mig.mode.current"
	)

	raw, err := exec.Command("nvidia-smi", "--format=csv,noheader", query).Output()
	if err != nil {
		return nil, err
	}

	gpus := []Info{}
	text := strings.Trim(string(raw), "\n")
	if text == "" {
		return gpus, nil
	}

	for _, row := range strings.Split(text, "\n") {
		cols := strings.Split(strings.TrimSpace(row), ",")
		if len(cols) != columnCount {
			continue
		}
		// Strip the space that follows each comma once, up front.
		for i, c := range cols {
			cols[i] = strings.TrimSpace(c)
		}

		total, parseErr := utils.ParseMemorySize(cols[10])
		if parseErr != nil {
			return nil, parseErr
		}
		if total == nil {
			return nil, errors.New("parse storage size error")
		}

		used, parseErr := utils.ParseMemorySize(cols[11])
		if parseErr != nil {
			return nil, parseErr
		}
		if used == nil {
			return nil, errors.New("parse storage size error")
		}

		gpus = append(gpus, Info{
			GPUName:          cols[0],
			DriverVersion:    cols[1],
			PersistenceMode:  cols[2] == enabled,
			FanSpeed:         cols[3],
			Temperature:      cols[4],
			PerformanceState: cols[5],
			BusID:            cols[6],
			DisplayActive:    cols[7] == enabled,
			PowerUsage:       cols[8],
			PowerCapacity:    cols[9],
			MemorySize:       *total,
			MemoryUsage:      *used,
			VBIOSVersion:     cols[12],
			MIGMode:          cols[13] == enabled,
		})
	}
	return gpus, nil
}
