export.py 3.23 KB
Newer Older
PRC-Huang's avatar
PRC-Huang committed
1
2
3
4
5
6
# --------------------------------------------------------
# InternImage
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

zhe chen's avatar
zhe chen committed
7
import argparse
PRC-Huang's avatar
PRC-Huang committed
8
9
10
11
12
13
import os
import time

import torch
from config import get_config
from models import build_model
zhe chen's avatar
zhe chen committed
14
15
from tqdm import tqdm

PRC-Huang's avatar
PRC-Huang committed
16
17
18
19
20
21

def get_args():
    """Parse command-line options and load the matching model config.

    Everything else is derived from ``--model_name``:
    the config is read from ``./configs/<model_name>.yaml``, the checkpoint
    path is ``<ckpt_dir>/<model_name>.pth``, and the square input resolution
    is the integer after the last ``_`` of the name (ignoring anything from
    the first ``.`` onward), e.g. ``internimage_t_1k_224`` -> ``224``.

    Returns:
        tuple: ``(args, cfg)``. ``args`` is the parsed namespace extended with
        ``cfg``, ``ckpt`` and ``size`` attributes; ``cfg`` is the result of
        ``get_config(args)``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, default='internimage_t_1k_224')
    parser.add_argument('--ckpt_dir', type=str, default='pretrained/')
    parser.add_argument('--onnx', default=False, action='store_true')
    parser.add_argument('--trt', default=False, action='store_true')
    args = parser.parse_args()

    name = args.model_name
    args.cfg = os.path.join('./configs', f'{name}.yaml')
    args.ckpt = os.path.join(args.ckpt_dir, f'{name}.pth')
    # The resolution is encoded as the trailing "_<int>" of the model name.
    stem = name.split('.')[0]
    args.size = int(stem.split('_')[-1])

    return args, get_config(args)

zhe chen's avatar
zhe chen committed
34

PRC-Huang's avatar
PRC-Huang committed
35
36
37
38
39
40
41
def get_model(args, cfg):
    """Build the network described by ``cfg`` and load pretrained weights.

    The checkpoint at ``args.ckpt`` is loaded onto the CPU and its ``'model'``
    entry is used as the state dict. Device placement is left to the caller.

    Returns:
        The model with the checkpoint weights loaded.
    """
    net = build_model(cfg)
    checkpoint = torch.load(args.ckpt, map_location='cpu')
    net.load_state_dict(checkpoint['model'])
    return net

zhe chen's avatar
zhe chen committed
42

PRC-Huang's avatar
PRC-Huang committed
43
44
45
46
47
48
49
50
51
52
53
54
def speed_test(model, input, iters=100, warmup=100):
    """Benchmark ``model`` on a fixed ``input`` and print the throughput.

    Runs ``warmup`` untimed calls, then times ``iters`` calls and prints the
    elapsed wall time and the number of calls per second.

    Args:
        model: Callable invoked as ``model(input)`` (a torch module or a
            TensorRT wrapper in this script).
        input: Argument passed unchanged to ``model`` on every call. (The
            name shadows the builtin but is kept for interface compatibility.)
        iters: Number of timed calls (default 100, as before).
        warmup: Number of untimed warm-up calls (default 100, as before).
    """
    def _sync():
        # CUDA kernels run asynchronously: without waiting for them, the
        # timer would stop after the last *launch*, not the last completion.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    # Outputs are discarded, so building an autograd graph is pure overhead
    # that would distort the measurement.
    with torch.no_grad():
        for _ in tqdm(range(warmup)):
            model(input)

        _sync()
        start = time.perf_counter()
        for _ in tqdm(range(iters)):
            model(input)
        _sync()
        end = time.perf_counter()

    elapsed = end - start
    th = iters / elapsed
    print(f'using time: {elapsed}, throughput {th}')

PRC-Huang's avatar
PRC-Huang committed
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71

def torch2onnx(args, cfg):
    """Export the pretrained model to ``<model_name>.onnx`` in the working directory.

    The model is moved to CUDA and exported with a random
    ``(1, 3, size, size)`` input. The ONNX graph's input and output are
    named ``'input'`` and ``'output'``; ``onnx2trt`` and ``check`` refer to
    the graph by these names.

    Returns:
        The CUDA model used for the export, in eval mode.
    """
    model = get_model(args, cfg).cuda()
    # Put normalization/dropout layers in inference mode explicitly rather
    # than relying on the exporter's default ``training`` handling.
    model.eval()

    onnx_name = f'{args.model_name}.onnx'
    dummy_input = torch.rand(1, 3, args.size, args.size).cuda()
    torch.onnx.export(model,
                      dummy_input,
                      onnx_name,
                      input_names=['input'],
                      output_names=['output'])

    return model

zhe chen's avatar
zhe chen committed
72

PRC-Huang's avatar
PRC-Huang committed
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
def onnx2trt(args):
    """Build a TensorRT engine from ``<model_name>.onnx`` using mmdeploy.

    ``args.model_name`` is passed as mmdeploy's output file prefix; ``check``
    later loads ``<model_name>.engine``. The optimization profile for the
    ONNX input named ``'input'`` uses the same ``(1, 3, size, size)`` shape
    for min/opt/max. The workspace limit is ``2**30`` bytes (1 GiB).
    """
    from mmdeploy.backend.tensorrt import from_onnx

    # A fresh list per key (min/opt/max all describe one fixed shape).
    profile = {
        key: [1, 3, args.size, args.size]
        for key in ('min_shape', 'opt_shape', 'max_shape')
    }
    from_onnx(
        f'{args.model_name}.onnx',
        args.model_name,
        {'input': profile},
        max_workspace_size=2**30,
    )

zhe chen's avatar
zhe chen committed
90

PRC-Huang's avatar
PRC-Huang committed
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
def check(args, cfg):
    """Compare the PyTorch model against the built TensorRT engine.

    Loads ``<model_name>.engine``, runs both backends on the same random
    ``(1, 3, size, size)`` CUDA input and prints:
    - the output shapes;
    - the max and mean absolute difference between the two outputs.

    It then benchmarks each backend with ``speed_test``.
    """
    from mmdeploy.backend.tensorrt.wrapper import TRTWrapper

    model = get_model(args, cfg).cuda()
    model.eval()
    trt_model = TRTWrapper(f'{args.model_name}.engine', ['output'])

    x = torch.randn(1, 3, args.size, args.size).cuda()

    # Pure inference: without no_grad, torch_out would carry an autograd
    # graph and the PyTorch benchmark would include gradient bookkeeping,
    # making the comparison with TensorRT unfair.
    with torch.no_grad():
        torch_out = model(x)
        trt_out = trt_model(dict(input=x))['output']

        print('torch out shape:', torch_out.shape)
        print('trt out shape:', trt_out.shape)

        abs_diff = (torch_out - trt_out).abs()
        print('max delta:', abs_diff.max())
        print('mean delta:', abs_diff.mean())

        speed_test(model, x)
        speed_test(trt_model, dict(input=x))

zhe chen's avatar
zhe chen committed
113

PRC-Huang's avatar
PRC-Huang committed
114
115
116
117
118
119
120
121
122
123
124
125
def main():
    """Run the export steps selected on the command line.

    ``--onnx`` exports ONNX only. ``--trt`` exports ONNX, converts it to a
    TensorRT engine and then checks the engine against PyTorch. With
    neither flag, a hint is printed and nothing is exported.
    """
    args, cfg = get_args()

    if not (args.onnx or args.trt):
        print('nothing to do: pass --onnx and/or --trt')
        return

    # Both modes start from the ONNX export.
    torch2onnx(args, cfg)
    print('torch -> onnx: success')

    if args.trt:
        onnx2trt(args)
        print('onnx -> trt: success')
        check(args, cfg)

zhe chen's avatar
zhe chen committed
126

PRC-Huang's avatar
PRC-Huang committed
127
128
if __name__ == "__main__":
    # Run the exporter only when executed as a script, not on import.
    main()