# trace_model.py: export a ResNet-18 to TorchScript ("resnet18.pt").
import os.path as osp

import torch
import torchvision

# Directory containing this script, and the directory two levels above it.
# NOTE(review): neither HERE nor ASSETS is used below; the exported model is
# written to the current working directory, not to either of these paths.
HERE = osp.dirname(osp.abspath(__file__))
ASSETS = osp.dirname(osp.dirname(HERE))

# Build a randomly initialised (untrained) ResNet-18. Calling the builder with
# no arguments yields untrained weights on every torchvision release: older
# releases default to ``pretrained=False`` and newer ones to ``weights=None``.
# Passing ``pretrained=False`` explicitly is deprecated since torchvision 0.13
# (superseded by ``weights``) and emits a warning there.
model = torchvision.models.resnet18()
# Put the model in evaluation mode (BatchNorm uses running statistics) before
# exporting it.
model.eval()

# Compile with TorchScript's *scripting* frontend (despite the variable name,
# this is torch.jit.script, not torch.jit.trace). Then serialize the module to
# "resnet18.pt" in the current working directory.
traced_model = torch.jit.script(model)
traced_model.save("resnet18.pt")