# trace_model.py — compile torchvision's ResNet-18 to TorchScript and save it.
import os.path as osp

import torch
import torchvision

# Absolute directory containing this script, and the directory two levels above it.
HERE = osp.dirname(osp.abspath(__file__))
# NOTE(review): ASSETS is not used by the code below. The output file is written
# relative to the current working directory, not to HERE or ASSETS. Confirm
# whether the artifact was meant to land in one of these directories.
ASSETS = osp.dirname(osp.dirname(HERE))

# Build ResNet-18 with torchvision's default constructor arguments.
# NOTE(review): whether the weights are random or pretrained depends on the
# installed torchvision's defaults. Pass an explicit weights argument if a
# specific set is required.
model = torchvision.models.resnet18()
# Put the module in evaluation mode before compiling, so that mode-dependent
# layers (e.g. dropout, batch norm) are exported with inference behavior.
model.eval()

# torch.jit.script compiles the module from its Python source; it does not trace
# it with example inputs, despite the variable/file naming.
traced_model = torch.jit.script(model)
# Serialize the compiled module to "resnet18.pt" in the current working directory.
traced_model.save("resnet18.pt")