trace_model.py 320 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14

import os.path as osp

import torch
import torchvision

# Absolute path of the directory that contains this script.
HERE = osp.split(osp.abspath(__file__))[0]
# Directory two levels above HERE (parent of HERE's parent).
ASSETS = osp.split(osp.split(HERE)[0])[0]

# Build the detection model with pretrained=False and switch it to evaluation
# mode in the same expression (Module.eval() returns the module itself).
# NOTE(review): newer torchvision releases deprecate `pretrained` in favour of
# `weights` -- confirm the targeted torchvision version before changing it.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False).eval()

# Compile with torch.jit.script (scripting, despite the "traced" name) and
# serialize the result. The file name has no directory part, so it is written
# relative to the current working directory.
traced_model = torch.jit.script(model)
torch.jit.save(traced_model, "fasterrcnn_resnet50_fpn.pt")