#!/usr/bin/env python3
import argparse

import torch
import torch.onnx

from model.retinanet import retinanet_from_backbone

def parse_args(add_help=True):
    parser = argparse.ArgumentParser(description='Convert a PyTorch RetinaNet checkpoint (.pth) to ONNX format', add_help=add_help)

    parser.add_argument('--input', required=True, help='input pth file')
    parser.add_argument('--output', default=None, help='output onnx file')

    parser.add_argument('--backbone', default='resnext50_32x4d',
                        choices=['resnet50', 'resnext50_32x4d', 'resnet101', 'resnext101_32x8d'],
                        help='The model backbone')
    parser.add_argument('--num-classes', default=264, type=int,
                        help='Number of detection classes')
    parser.add_argument('--trainable-backbone-layers', default=3, type=int,
                        help='Number of trainable backbone layers')

    parser.add_argument('--image-size', default=None, nargs=2, type=int,
                        help='Image size (height width) for export. If not set, the spatial axes will be dynamic')
    parser.add_argument('--batch-size', default=None, type=int,
                        help='Input batch size. If not set, the batch axis will be dynamic')
    parser.add_argument('--data-layout', default="channels_first", choices=['channels_first', 'channels_last'],
                        help="Model data layout")
    parser.add_argument('--device', default='cuda', help='Device to build the model on, e.g. cuda or cpu')

    args = parser.parse_args()

    args.output = args.output or ('retinanet_' + args.backbone + '.onnx')
    return args
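
# Example invocation (a hedged sketch; the checkpoint path below is a
# placeholder, not a file shipped with the repository):
#
#   python pth_to_onnx.py --input retinanet_model.pth \
#       --backbone resnext50_32x4d --batch-size 1 --image-size 800 800
#
# Omitting --batch-size and --image-size exports the model with dynamic
# batch and spatial axes instead of fixed ones.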

def main(args):
    batch_size = args.batch_size or 1
    image_size = args.image_size or [800, 800]

    print("Creating model")
    model = retinanet_from_backbone(backbone=args.backbone,
                                    num_classes=args.num_classes,
                                    image_size=image_size,
                                    data_layout=args.data_layout,
                                    pretrained=False,
                                    trainable_backbone_layers=args.trainable_backbone_layers)
    device = torch.device(args.device)
    model.to(device)

    print("Loading model")
    checkpoint = torch.load(args.input, map_location=device)
    model.load_state_dict(checkpoint['model'])

    print("Creating input tensor")
    inputs = torch.randn(batch_size, 3, image_size[0], image_size[1],
                         device=device,
                         requires_grad=False,
                         dtype=torch.float)
    # Output dynamic axes
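    # RetinaNet returns a variable number of detections per image, so the
    # leading axis of each output tensor is left dynamic.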
    dynamic_axes = {
        'boxes': {0 : 'num_detections'},
        'scores': {0 : 'num_detections'},
        'labels': {0 : 'num_detections'},
    }
    # Input dynamic axes
    if (args.batch_size is None) or (args.image_size is None):
        dynamic_axes['images'] = {}
        if args.batch_size is None:
            dynamic_axes['images'][0] = 'batch_size'
        if args.image_size is None:
            # NCHW input: axis 2 is height, axis 3 is width
            dynamic_axes['images'][2] = 'height'
            dynamic_axes['images'][3] = 'width'

    print("Exporting the model")
    model.eval()
    torch.onnx.export(model,
                      inputs,
                      args.output,
                      export_params=True,
                      opset_version=13,
                      do_constant_folding=False,
                      input_names=['images'],
                      output_names=['boxes', 'scores', 'labels'],
                      dynamic_axes=dynamic_axes)
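
    # Optional sanity check: a minimal sketch assuming the `onnx` package is
    # available (it is not a dependency of the original script). The checker
    # validates the structural well-formedness of the exported graph.
    try:
        import onnx
        onnx_model = onnx.load(args.output)
        onnx.checker.check_model(onnx_model)
        print("ONNX model check passed")
    except ImportError:
        print("onnx package not installed; skipping model validation")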


if __name__ == "__main__":
    args = parse_args()
    main(args)