# trace_model.py
import os.path as osp

import torch
import torchvision

# Directory containing this script, and the directory two levels above it.
# NOTE(review): neither is used by this script (the output below is written
# relative to the current working directory); kept so any importer relying on
# these module-level names keeps working.
HERE = osp.dirname(osp.abspath(__file__))
ASSETS = osp.dirname(osp.dirname(HERE))

# File name of the saved TorchScript module, written relative to the current
# working directory.
OUTPUT_FILENAME = "fasterrcnn_resnet50_fpn.pt"


def main():
    """Script an untrained Faster R-CNN (ResNet-50 FPN) model and save it.

    The model is built with no pretrained weights for either the detector
    (``weights=None``) or the backbone (``weights_backbone=None``), so no
    weights are downloaded. It is put in eval mode, compiled with
    ``torch.jit.script``, and the resulting module is saved to
    ``OUTPUT_FILENAME`` in the current working directory.
    """
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
        weights=None, weights_backbone=None
    )
    # Switch to eval mode before compiling so the exported module is in the
    # same mode the original script exported it in.
    model.eval()

    # torch.jit.script compiles the model (it does not trace it), hence the name.
    scripted_model = torch.jit.script(model)
    scripted_model.save(OUTPUT_FILENAME)


if __name__ == "__main__":
    main()