pytorch2onnx.py 4.54 KB
Newer Older
luopl's avatar
luopl committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import argparse
import os
import sys
import warnings
from typing import Dict, List, Optional, Tuple

import numpy as np
import onnx
import torch
from torch import Tensor

sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from util import utils
from util.lazy_load import Config


class ONNXDetector:
    """Run an exported ONNX detector with ONNX Runtime using I/O binding.

    The session is created with ``CUDAExecutionProvider`` preferred and
    ``CPUExecutionProvider`` as fallback. Only inference with a single image
    (batch size 1) is supported.
    """

    def __init__(self, onnx_file):
        """Create the inference session and its I/O binding.

        Args:
            onnx_file: Path to the exported ONNX model.
        """
        # Imported lazily so the export path does not require onnxruntime.
        import onnxruntime

        self.session = onnxruntime.InferenceSession(
            onnx_file, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
        )
        self.io_binding = self.session.io_binding()
        # Ask the session which providers it actually loaded: a GPU build of
        # onnxruntime reports device "GPU" even when the CUDA provider could not
        # be initialised and the session fell back to the CPU provider.
        self.is_cuda_available = "CUDAExecutionProvider" in self.session.get_providers()

    def __call__(self, images: List[Tensor], targets: Optional[List[Dict]] = None):
        """Run the ONNX model on one image.

        Args:
            images: A list/tuple holding exactly one ``(C, H, W)`` tensor, or an
                already-batched tensor whose first dimension is 1.
            targets: Ignored (a warning is emitted if given); accepted only for
                call-compatibility with the PyTorch model.

        Returns:
            The session outputs, in the session's output order, as returned by
            ``IOBinding.copy_outputs_to_cpu()``.

        Raises:
            AssertionError: If more than one image is given or an image is not
                three-dimensional.
        """
        if targets is not None:
            warnings.warn("Currently ONNXDetector only support inference, targets will be ignored")
        assert len(images) == 1, "Currently ONNXDetector only support batch_size=1 for inference"
        assert images[0].ndim == 3, "Each image must be with three dimensions of C, H, W"
        if isinstance(images, (list, tuple)):
            images = torch.stack(images)

        # The raw buffer is handed to ONNX Runtime as dense float32 data, so the
        # tensor must really be float32 and contiguous; otherwise ORT would read
        # memory with the wrong element size or layout.
        images = images.to(dtype=torch.float32).contiguous()
        if not self.is_cuda_available:
            images = images.cpu()
        device_type = images.device.type
        # Bind to the GPU the tensor actually lives on (CPU tensors have no index).
        device_id = images.device.index if images.device.index is not None else 0

        # `images` stays referenced until this method returns, which keeps the
        # bound buffer alive while the session runs.
        self.io_binding.bind_input(
            name="images",
            device_type=device_type,
            device_id=device_id,
            element_type=np.float32,
            shape=tuple(images.shape),
            buffer_ptr=images.data_ptr(),
        )
        for output in self.session.get_outputs():
            self.io_binding.bind_output(output.name)

        # run session to get outputs
        self.session.run_with_iobinding(self.io_binding)
        return self.io_binding.copy_outputs_to_cpu()


def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args():
    """Parse command-line options for the PyTorch-to-ONNX conversion.

    ``--shape`` is normalised to a 2-tuple: a single value is used for both
    spatial dimensions, and any count other than one or two is a CLI error.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description="Convert a pytorch model to ONNX model")

    # model parameters
    parser.add_argument("--model-config", type=str, default=None, help="model config file")
    parser.add_argument("--checkpoint", type=str, default=None, help="optional checkpoint to load")
    parser.add_argument(
        "--shape",
        type=int,
        nargs="+",
        default=(1333, 800),
        help="dummy input size used for export: one value (square) or two values",
    )

    # save parameters
    parser.add_argument("--save-file", type=str, required=True, help="output ONNX file")

    # onnx parameters
    parser.add_argument("--opset-version", type=int, default=17)
    # `type=bool` would treat any non-empty string (including "False") as True,
    # so parse the value explicitly; a bare `--dynamic-export` also means True.
    parser.add_argument(
        "--dynamic-export", type=_str2bool, nargs="?", const=True, default=True,
        help="export with dynamic batch/height/width input axes (true/false)",
    )
    parser.add_argument("--simplify", action="store_true", help="simplify with onnxsim")
    parser.add_argument("--verify", action="store_true", help="compare ONNX and PyTorch outputs")

    args = parser.parse_args()
    if len(args.shape) == 1:
        args.shape = (args.shape[0], args.shape[0])
    elif len(args.shape) != 2:
        parser.error("--shape expects one value (square) or two values")
    args.shape = tuple(args.shape)
    return args


def _simplify_onnx(save_file):
    """Simplify the ONNX file in place with onnxsim; warn if its check fails."""
    import onnxsim

    simplified_model, check_ok = onnxsim.simplify(save_file)
    if check_ok:
        onnx.save(simplified_model, save_file)
        print(f"Successfully simplified ONNX model: {save_file}")
    else:
        warnings.warn("Failed to simplify ONNX model.")


def _verify_onnx(model, image, save_file):
    """Run the ONNX checker, then compare ONNX Runtime and PyTorch outputs on `image`."""
    # check by onnx
    onnx.checker.check_model(onnx.load(save_file))

    # check onnx results and pytorch results
    onnx_results = ONNXDetector(save_file)(image)
    # Without no_grad the score/box outputs require grad, and numpy cannot
    # convert such tensors (Tensor.numpy() raises), so the comparison would crash.
    with torch.no_grad():
        # NOTE(review): outputs are matched by position, assuming the model's
        # first result dict yields values in the exported output order.
        pytorch_results = [
            res.detach().cpu().numpy() if isinstance(res, Tensor) else np.asarray(res)
            for res in model(image)[0].values()
        ]
    err_msg = "The numerical values are different between Pytorch and ONNX. "
    err_msg += "But it does not necessarily mean the exported ONNX is problematic."
    for onnx_res, pytorch_res in zip(onnx_results, pytorch_results):
        np.testing.assert_allclose(onnx_res, pytorch_res, rtol=1e-3, atol=1e-5, err_msg=err_msg)
    print("The numerical values are the same between Pytorch and ONNX")


def pytorch2onnx():
    """Export the model described by ``--model-config`` to ONNX.

    Builds the model from the config, optionally loads ``--checkpoint``, and
    exports it using a random ``(1, 3, shape[0], shape[1])`` dummy input.
    With ``--simplify`` the file is simplified in place with onnxsim. With
    ``--verify`` it is checked by the ONNX checker, and ONNX Runtime outputs are
    compared numerically with the PyTorch outputs.
    """
    # get args from parser
    args = parse_args()
    model = Config(args.model_config).model
    model.eval()
    if args.checkpoint:
        checkpoint = torch.load(args.checkpoint, map_location="cpu")
        # Accept checkpoints that wrap the weights under a "model" key as well
        # as plain state dicts.
        utils.load_state_dict(model, checkpoint["model"] if "model" in checkpoint else checkpoint)
    image = torch.randn(1, 3, args.shape[0], args.shape[1])

    if args.dynamic_export:
        # Only the input axes are declared dynamic; output axes are left to the exporter.
        dynamic_axes = {
            "images": {
                0: "batch",
                2: "height",
                3: "width",
            },
        }
    else:
        dynamic_axes = None
    torch.onnx.export(
        model=model,
        args=image,
        f=args.save_file,
        input_names=["images"],
        output_names=["scores", "labels", "boxes"],
        dynamic_axes=dynamic_axes,
        opset_version=args.opset_version,
    )
    print(f"Successfully exported ONNX model: {args.save_file}")

    if args.simplify:
        _simplify_onnx(args.save_file)

    if args.verify:
        _verify_onnx(model, image, args.save_file)


if __name__ == "__main__":
    # Run the conversion only when this file is executed as a script, not on import.
    pytorch2onnx()