"""Compile a TorchScript model with TVM (tvm.contrib.torch, ROCm target) and time its forward pass."""
import torch
import time
import tvm
from tvm.contrib.torch import compile
import torch.onnx as onnx
# Run on the GPU when PyTorch reports one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Random sample input of shape (1, 3, 224, 224), moved onto the chosen device.
x = torch.rand([1, 3, 224, 224]).to(device)

# Alternative loading/export paths tried earlier (kept for reference, not executed):
#   checkpoint = torch.load("./yolov8n.pt")
#   model.load(weights='yolov8n.pt')
#   torch.onnx.export(model.predictor.model, x, "./yolov8n.onnx")
#   model.export(format='onnx', imgsz=[384, 640], device="cuda")

# TorchScript model that is passed to TVM's compile() for compilation.
model_jit = torch.jit.load("./model.pt")


# Compilation options for tvm.contrib.torch.compile.
# The input shape is taken from the sample tensor `x`, so the compiled graph's
# static shape cannot drift from the tensor actually fed to it at run time.
option = {
    "input_infos": [
        # (graph input name, static input shape)
        ("x", tuple(x.shape)),
    ],
    "default_dtype": "float32",
    "export_dir": "pytorch_compiled",
    "num_outputs": 1,
    "tuning_n_trials": 0,  # set zero to skip tuning
    "tuning_log_file": "tuning.log",
    # ROCm target; --libs asks TVM to use the MIOpen / rocBLAS vendor libraries.
    "target": "rocm --libs=miopen,rocblas",
    "device": tvm.rocm(),
}

# Compile the TorchScript model into a TVM-backed module.
pytorch_tvm_module = compile(model_jit, option)

# Number of timed forward calls; keep >= 1 so `outputs` is bound for the final print.
num_runs = 1


def _sync_device():
    """Block until queued GPU work finishes so wall-clock timings cover the whole call."""
    # torch.cuda.synchronize() raises when no CUDA/HIP device is present, which is
    # exactly the CPU fallback case chosen when `device` was selected.
    if torch.cuda.is_available():
        torch.cuda.synchronize()


# NOTE: despite the message, this times the TVM-compiled module, not eager PyTorch.
print("Run PyTorch...")
# Flush pending work (e.g. the host-to-device copy of `x`) so it is not billed
# to the first timed call.
_sync_device()
for _ in range(num_runs):
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    # Alternative: run a previously exported TorchScript artifact instead:
    #    module = torch.jit.load("./model_tvm.pt")
    #    outputs = module([x])
    outputs = pytorch_tvm_module.forward([x])
    _sync_device()
    elapsed_ms = 1000 * (time.perf_counter() - start)
    print(f"{elapsed_ms:.3f} ms")
print(outputs[0].shape)

