# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp
import warnings

import numpy as np


def get_GiB(x: int):
    """return x GiB."""
    return x * (1 << 30)


def onnx2tensorrt(onnx_file,
                  trt_file,
                  input_shape,
                  max_batch_size,
                  fp16_mode=False,
                  verify=False,
                  workspace_size=1):
    """Create tensorrt engine from onnx model.

    Args:
        onnx_file (str): Filename of the input ONNX model file.
        trt_file (str): Filename of the output TensorRT engine file.
        input_shape (list[int]): Input shape of the model.
            e.g. [1, 3, 224, 224].
        max_batch_size (int): Max batch size of the model.
        fp16_mode (bool, optional): Whether to enable fp16 mode.
            Defaults to False.
        verify (bool, optional): Whether to verify the converted model.
            Defaults to False.
        workspace_size (int, optional): Maximum workspace size of the GPU
            in GiB. Defaults to 1.
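
    Example:
        A hypothetical usage sketch; the file names below are placeholders,
        not files shipped with this repo.

        >>> onnx2tensorrt(
        ...     'mobilenet_v2.onnx',
        ...     'mobilenet_v2.trt',
        ...     input_shape=[1, 3, 224, 224],
        ...     max_batch_size=4,
        ...     fp16_mode=True,
        ...     verify=True,
        ...     workspace_size=2)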
    """
    import onnx
    from mmcv.tensorrt import TRTWraper, onnx2trt, save_trt_engine
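    # note: "TRTWraper" (sic) is the class name actually exported by the
    # mmcv.tensorrt versions this script targets, so the spelling is kept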

    onnx_model = onnx.load(onnx_file)
    # create trt engine and wrapper
    assert max_batch_size >= 1
    max_shape = [max_batch_size] + list(input_shape[1:])
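    # mmcv's onnx2trt takes, for each named input, a list of
    # [min_shape, opt_shape, max_shape]; here min and opt are both pinned
    # to the given input_shape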
    opt_shape_dict = {'input': [input_shape, input_shape, max_shape]}
    max_workspace_size = get_GiB(workspace_size)
    trt_engine = onnx2trt(
        onnx_model,
        opt_shape_dict,
        fp16_mode=fp16_mode,
        max_workspace_size=max_workspace_size)
    save_dir, _ = osp.split(trt_file)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    save_trt_engine(trt_engine, trt_file)
    print(f'Successfully created TensorRT engine: {trt_file}')

    if verify:
        import onnxruntime as ort
        import torch

        input_img = torch.randn(*input_shape)
        input_img_cpu = input_img.detach().cpu().numpy()
        input_img_cuda = input_img.cuda()

        # Get results from ONNXRuntime
        session_options = ort.SessionOptions()
        sess = ort.InferenceSession(onnx_file, session_options)

        # get input and output names
        input_names = [_.name for _ in sess.get_inputs()]
        output_names = [_.name for _ in sess.get_outputs()]

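        # passing None as the output list makes ONNX Runtime return every
        # model output, in the order reported by sess.get_outputs()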
        onnx_outputs = sess.run(None, {
            input_names[0]: input_img_cpu,
        })

        # Get results from TensorRT
        trt_model = TRTWraper(trt_file, input_names, output_names)
        with torch.no_grad():
            trt_outputs = trt_model({input_names[0]: input_img_cuda})
        trt_outputs = [
            trt_outputs[_].detach().cpu().numpy() for _ in output_names
        ]

        # Compare results
        np.testing.assert_allclose(
            onnx_outputs[0], trt_outputs[0], rtol=1e-05, atol=1e-05)
        print('The numerical values match between ONNXRuntime and TensorRT '
              'within the given tolerance')


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert MMClassification models from ONNX to TensorRT')
    parser.add_argument('model', help='Filename of the input ONNX model')
    parser.add_argument(
        '--trt-file',
        type=str,
        default='tmp.trt',
        help='Filename of the output TensorRT engine')
    parser.add_argument(
        '--verify',
        action='store_true',
        help='Verify the outputs of ONNXRuntime and TensorRT')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[224, 224],
        help='Input size of the model')
    parser.add_argument(
        '--max-batch-size',
        type=int,
        default=1,
        help='Maximum batch size of TensorRT model')
    parser.add_argument('--fp16', action='store_true', help='Enable fp16 mode')
    parser.add_argument(
        '--workspace-size',
        type=int,
        default=1,
        help='Max workspace size of GPU in GiB')
    args = parser.parse_args()
    return args


if __name__ == '__main__':

    args = parse_args()

    if len(args.shape) == 1:
        input_shape = (1, 3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (1, 3) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    # Create TensorRT engine
    onnx2tensorrt(
        args.model,
        args.trt_file,
        input_shape,
        args.max_batch_size,
        fp16_mode=args.fp16,
        verify=args.verify,
        workspace_size=args.workspace_size)
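
    # Hypothetical example invocation (the model file names are
    # placeholders, not files shipped with this repo):
    #   python onnx2tensorrt.py mobilenet_v2.onnx \
    #       --trt-file mobilenet_v2.trt --shape 224 224 \
    #       --max-batch-size 4 --fp16 --verify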

    # Following strings of text style are from colorama package
    bright_style, reset_style = '\x1b[1m', '\x1b[0m'
    red_text, blue_text = '\x1b[31m', '\x1b[34m'
    white_background = '\x1b[107m'

    msg = white_background + bright_style + red_text
    msg += 'DeprecationWarning: This tool will be deprecated in the future. '
    msg += blue_text + 'Welcome to use the unified model deployment toolbox '
    msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
    msg += reset_style
    warnings.warn(msg)