# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json
import subprocess
import time
import traceback
from xml.dom import minidom


def collect_gpu_usage(node_id):
    """Run ``rocm-smi`` and return the parsed GPU metrics.

    Args:
        node_id: Accepted for interface compatibility; this function does
            not read it.

    Returns:
        Whatever ``parse_nvidia_smi_result`` returns for the command output.
        If launching the command or parsing raises, the traceback is printed
        and the result of ``gen_empty_gpu_metric()`` is returned instead.
    """
    # Utilisation, memory utilisation, visible-VRAM info and device id, as JSON.
    smi_command = 'rocm-smi --showuse --showmemuse --showmeminfo vis_vram  --showid --json'.split()
    try:
        raw_report = subprocess.check_output(smi_command)
        return parse_nvidia_smi_result(raw_report)
    except Exception:
        # Best effort: a missing or failing tool must not break the caller.
        traceback.print_exc()
        return gen_empty_gpu_metric()


def parse_nvidia_smi_result(smi):
    """Convert ``rocm-smi --json`` output into the GPU metrics dictionary.

    Despite the name, the caller in this file feeds it the output of
    ``rocm-smi``, not ``nvidia-smi``.

    Fields of each card are read by position, not by key name.
    NOTE(review): this assumes each card reports, in order, device id/type,
    GPU utilisation, memory utilisation, total memory and used memory (the
    last two presumably in bytes) -- confirm against the rocm-smi version
    in use.

    Args:
        smi: JSON text (``str`` or ``bytes``) whose top level is an object
            mapping a card name to an object of that card's fields.

    Returns:
        A dict with ``"Timestamp"`` (local time string), ``"gpuCount"`` and
        ``"gpuInfos"`` (one dict per card, in input order). Memory figures
        are divided by 1048576, rounded to two decimals and suffixed with
        ``"MB"``. If anything goes wrong while parsing, the traceback is
        printed and an empty dict is returned.
    """
    try:
        # Parse as JSON rather than eval(): eval would execute arbitrary
        # expressions found in the tool output and rejects the JSON literals
        # true/false/null.
        cards = json.loads(smi)

        output = {}
        output["Timestamp"] = time.asctime(time.localtime())
        output["gpuCount"] = len(cards)
        output["gpuInfos"] = []

        for gpu_index, fields in enumerate(cards.values()):
            # Positional access to the card's fields (see docstring note).
            values = list(fields.values())

            mem_total = round(float(values[3]) / 1048576, 2)
            mem_used = round(float(values[4]) / 1048576, 2)
            # Round the difference too, otherwise binary floating-point
            # noise (e.g. "...0000001") leaks into the reported string.
            mem_free = round(mem_total - mem_used, 2)

            gpu_info = {}
            gpu_info['index'] = gpu_index
            # str() keeps this working if the tool reports bare numbers.
            gpu_info['gpuUtil'] = str(values[1]) + "%"
            gpu_info['gpuMemUtil'] = str(values[2]) + "%"
            # The per-GPU process count is not queried; a fixed 1 is reported.
            gpu_info['activeProcessNum'] = 1
            gpu_info['gpuType'] = values[0]
            gpu_info['gpuMemTotal'] = str(mem_total) + "MB"
            gpu_info['gpuMemUsed'] = str(mem_used) + "MB"
            gpu_info['gpuMemFree'] = str(mem_free) + "MB"

            output["gpuInfos"].append(gpu_info)
    except Exception:
        # Best effort: report the problem and hand back an empty result.
        traceback.print_exc()
        output = {}
    return output


def gen_empty_gpu_metric():
    """Build the metrics dictionary that reports zero GPUs.

    Returns:
        A dict with the current local time under ``"Timestamp"``,
        ``"gpuCount"`` set to 0 and an empty ``"gpuInfos"`` list. If building
        it raises, the traceback is printed and an empty dict is returned.
    """
    try:
        return {
            "Timestamp": time.asctime(time.localtime()),
            "gpuCount": 0,
            "gpuInfos": [],
        }
    except Exception:
        traceback.print_exc()
        return {}
