import os
import time

import onnxruntime as ort
import torch
from PIL import Image, ImageDraw
from torchvision.transforms import ToTensor
# Load the test image and resize it to the fixed 640x640 network input.
im = Image.open('datasets/000000033109.jpg').convert('RGB')  # TODO: change the inference image path
im = im.resize((640, 640))
# ToTensor converts the PIL image to a float CHW tensor scaled to [0, 1];
# ``[None]`` prepends a batch axis of size 1.
im_data = ToTensor()(im)[None]
print(im_data.shape)
# Value for the model's "orig_target_sizes" input. The image was resized to
# 640x640 above, so this matches the image the boxes are later drawn on.
# NOTE(review): presumably the exported graph rescales boxes to this size — confirm.
size = torch.tensor([[640, 640]])
# NOTE(review): GPU builds of onnxruntime may require an explicit
# ``providers=[...]`` argument here — confirm for the installed build.
sess = ort.InferenceSession("model/onnx/rtdetr_r101vd_6x_coco.onnx")  # TODO: change the ONNX model path

# perf_counter is monotonic and high-resolution, unlike time.time(), so it is
# the appropriate clock for measuring an interval. This times a single (the
# first) call, which may include one-off warm-up cost.
start = time.perf_counter()
output = sess.run(
    # None requests every graph output, returned in the graph's output order.
    output_names=None,
    input_feed={'images': im_data.numpy(), "orig_target_sizes": size.numpy()},
)
end = time.perf_counter()
elapsed = end - start
# Elapsed seconds and the equivalent frames per second for this one call.
print(elapsed, 1.0 / elapsed)

# Unpack the model outputs.
# NOTE(review): assumes the exported graph emits its outputs in the order
# labels, boxes, scores — confirm against the ONNX export.
labels, boxes, scores = output
print("labels shape = ", labels.shape)
print("boxes shape = ", boxes.shape)
print("scores shape = ", scores.shape)
draw = ImageDraw.Draw(im)
thrh = 0.6  # minimum score for a detection to be drawn

# Iterate over the leading (batch) axis of the outputs, one entry per image.
for i, (lab_i, box_i, scr_i) in enumerate(zip(labels, boxes, scores)):
    keep = scr_i > thrh  # boolean mask of detections above the threshold
    lab = lab_i[keep]
    box = box_i[keep]

    print(i, keep.sum())

    # Pair every kept box with its OWN label; indexing the labels with the
    # batch index ``i`` would stamp the same label on every box.
    for lb, b in zip(lab, box):
        coords = b.tolist()  # plain Python floats for Pillow
        draw.rectangle(coords, outline='red')
        draw.text((coords[0], coords[1]), text=str(lb), fill='blue')

out_path = 'result/test.jpg'  # TODO: change the output path for the rendered result
# Image.save does not create missing directories, so make sure it exists.
os.makedirs(os.path.dirname(out_path) or '.', exist_ok=True)
im.save(out_path)