import onnx
import tvm
from PIL import Image
import cv2
from tvm import relay
import numpy as np
from yolov5s_pred_utils import non_max_suppression

# Load the exported YOLOv5s ONNX graph that is compiled further down.
# (An earlier experiment loaded 'model-zoo/googlenet.onnx' instead.)
onnx_model = onnx.load('./yolov5s.onnx')

# Build the network input from the test image.
# convert('RGB') pins the image to exactly three 8-bit channels: without it a
# grayscale file yields a 2-D array (the transpose below raises) and a
# CMYK/RGBA file yields four channels (wrong input shape for the model).
# NOTE(review): this is a plain resize, so the aspect ratio is not preserved;
# confirm that matches how the model was exported/evaluated.
img = Image.open('./cow.jpg').convert('RGB').resize((640, 640))
# PIL yields height x width x channel; reorder to channel x height x width
# and scale the 8-bit values into [0, 1] as float32.
img = np.array(img).transpose((2, 0, 1)).astype('float32')
img = img/255.0
# Prepend a batch axis of size 1.
x = img[np.newaxis, :]

# Compilation target and the device the compiled module will run on.
# Disabled alternatives kept from earlier experiments:
#   targets: "rocm", "llvm", "rocm -libs=miopen"
#   device:  tvm.cpu(0)
#   input:   np.random.rand(1,3,224,224).astype("float32")/255
target = "rocm -libs=miopen,rocblas"
dev = tvm.rocm(0)

# The name of the first graph input keys the shape dictionary handed to the
# Relay importer; its shape is taken from the preprocessed image tensor.
input_name = onnx_model.graph.input[0].name
shape_dict = {input_name: x.shape}

print(input_name)
print('shape_dict', shape_dict)

# Element type used both when importing the model and when uploading the input.
dtype = 'float32'

# Translate the ONNX graph into a Relay module plus its parameter dictionary.
mod, params = relay.frontend.from_onnx(onnx_model, shape_dict, dtype=dtype)

# NOTE(review): graph_runtime is not called by any statement that executes in
# this script; the import is kept so import-time behaviour stays unchanged.
from tvm.contrib import graph_runtime, graph_executor

# Compile the module for the selected target at optimisation level 1.
# (Earlier variants used relay.build_config(opt_level=2) with
# relay.build_module.build, or relay.build_module.create_executor("graph", ...).)
with tvm.transform.PassContext(opt_level=1):
    lib = relay.build(mod, target=target, params=params)

# Instantiate the compiled graph on the device, bind the image tensor to the
# model's input and execute one forward pass.
m = graph_executor.GraphModule(lib["default"](dev))
m.set_input(input_name, tvm.nd.array(x.astype(dtype)))
m.run()


# Disabled snippet, parked in a bare string literal so it never executes.
# It sketches writing the compiled library, graph JSON and parameters to
# out/googlenet.* and then reloading and running them via graph_runtime.
# NOTE(review): it references `graph`, which no executing statement in this
# file defines - presumably left over from an older build API; the googlenet
# file names also do not match the YOLOv5s model loaded above. Confirm and
# update before re-enabling.
'''
print('output model files')
libpath = 'out/googlenet.so'
lib.export_library(libpath)

graph_json_path = 'out/googlenet.json'
with open(graph_json_path, 'w')as f:
	f.write(graph)

params_path = 'out/googlenet.params'
with open(params_path, 'wb')as f:
	f.write(relay.save_param_dict(params))


load_json = open(graph_json_path).read()
load_lib = tvm.runtime.load_module(libpath)
load_params = bytearray(open(params_path, 'rb').read())
ctx = tvm.rocm()
module = graph_runtime.create(load_json,load_lib,ctx)
module.load_params(load_params)
module.run()
'''
# Copy the first output of the executed module back to the host as a NumPy
# array. (A disabled variant read it from `module` instead of `m`.)
# NOTE(review): newer TVM releases spell this `.numpy()`; `.asnumpy()` is kept
# here for compatibility - confirm against the installed TVM version.
output = m.get_output(0).asnumpy()

# Run the project's non-maximum-suppression helper on the raw predictions.
# NOTE(review): presumably the helper accepts a NumPy array - verify against
# yolov5s_pred_utils.
pred = non_max_suppression(
    output,
    conf_thres=0.1,
    iou_thres=0.50,
    classes=None,
    agnostic=False,
    multi_label=False,
    max_det=1000,
)
print(pred)

# Debug dump: maximum value and its index along axis 1 of the raw output.
print(output.max(axis=1))
print(output.argmax(axis=1))
