script_model.py 326 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
import torch
from torchvision import models

for model, name in (
    (models.resnet18(weights=None), "resnet18"),
    (models.detection.fasterrcnn_resnet50_fpn(weights=None, weights_backbone=None), "fasterrcnn_resnet50_fpn"),
):
    model.eval()
    traced_model = torch.jit.script(model)
    traced_model.save(f"{name}.pt")