gen_wts.py 3.1 KB
Newer Older
dengjb's avatar
update  
dengjb committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# encoding: utf-8

import sys
import time
import struct
import argparse
sys.path.append('.')

import torch
import torchvision
#from torchsummary import summary

from fastreid.config import get_cfg
from fastreid.modeling.meta_arch import build_model
from fastreid.utils.checkpoint import Checkpointer

sys.path.append('./projects/FastDistill')
from fastdistill import *

def setup_cfg(args):
    """Build the run configuration from the parsed command-line arguments.

    Starts from the object returned by ``get_cfg()``, merges in the file
    named by ``args.config_file`` and then the ``args.opts`` override list,
    in that order, and returns the resulting config object.
    """
    config = get_cfg()
    # The file is merged first, then the command-line list.
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)
    return config

def get_parser():
    """Create the command-line parser for the weight-export script.

    Recognised arguments:
        --config-file  path to the config file (no default).
        --wts_path     output path for the ``.wts`` file (default
                       ``./trt_demo``).
        --show_model, --verify, --benchmark
                       boolean switches, all off by default.
        opts           trailing positional that collects every remaining
                       command-line token (default: empty list).
    """
    cli = argparse.ArgumentParser(description="Encode pytorch weights for tensorrt.")
    cli.add_argument("--config-file", metavar="FILE", help="path to config file")
    cli.add_argument(
        "--wts_path",
        default='./trt_demo',
        help='path to save tensorrt weights file(.wts)',
    )

    # The three on/off switches only differ by name and help text; they are
    # registered in this order so the generated --help output is unchanged.
    switches = (
        ("--show_model", 'print model architecture'),
        ("--verify", 'print model output for verify'),
        ("--benchmark", 'preprocessing + inference time'),
    )
    for flag, description in switches:
        cli.add_argument(flag, action='store_true', help=description)

    cli.add_argument(
        "opts",
        nargs=argparse.REMAINDER,
        default=[],
        help="Modify config options using the command-line 'KEY VALUE' pairs",
    )
    return cli

def gen_wts(args):
    """Dump the weights of the module-level ``model`` to a text ``.wts`` file.

    Thanks to https://github.com/wang-xinyu/tensorrtx

    File layout:
        * first line: the number of entries in ``model.state_dict()``;
        * then one line per entry: ``<key> <count> <hex> <hex> ...`` where
          ``<count>`` is the number of elements after flattening and each
          ``<hex>`` is one element packed as a big-endian 32-bit float
          (``struct.pack(">f", ...)``) and written as hex digits.

    Args:
        args: object with a ``wts_path`` attribute naming the output file.

    Note:
        ``model`` is not a parameter; it is looked up as a module global, so
        it must be assigned before this function is called.
    """
    print("Wait for it: {} ...".format(args.wts_path))
    # Build the state dict once so the header count and the body are
    # guaranteed to describe the same snapshot.
    state = model.state_dict()
    # The context manager guarantees the file is flushed and closed even if
    # encoding one of the tensors raises.
    with open(args.wts_path, 'w') as wts_file:
        wts_file.write("{}\n".format(len(state)))
        for key, tensor in state.items():
            flat = tensor.reshape(-1).cpu().numpy()
            # One write per tensor instead of two tiny writes per element.
            encoded = "".join(
                " " + struct.pack(">f", float(value)).hex() for value in flat
            )
            wts_file.write("{} {}{}\n".format(key, len(flat), encoded))
        
if __name__ == '__main__':
    # Script entry point: build the model described by the config, load its
    # checkpoint, optionally print/verify/benchmark it, then export the
    # weights with gen_wts().
    args = get_parser().parse_args()
    cfg = setup_cfg(args)
    # NOTE(review): presumably disabled because the full checkpoint is loaded
    # from cfg.MODEL.WEIGHTS below -- confirm against fastreid's backbone code.
    cfg.MODEL.BACKBONE.PRETRAIN = False
    print("[Config]: \n", cfg)

    # `model` must remain a module-level name: gen_wts() reads it as a global.
    model = build_model(cfg)

    if args.show_model:
        print('[Model]: \n', model)

    print("Load model from: ", cfg.MODEL.WEIGHTS)
    Checkpointer(model).load(cfg.MODEL.WEIGHTS)

    model = model.to(cfg.MODEL.DEVICE)
    model.eval()

    if args.verify:
        # Constant all-255 input of shape (1, 3, SIZE_TEST[0], SIZE_TEST[1]),
        # so the printed output can be compared against another run.
        dummy_input = torch.ones(1, 3, cfg.INPUT.SIZE_TEST[0], cfg.INPUT.SIZE_TEST[1]).to(cfg.MODEL.DEVICE) * 255.
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            out = model(dummy_input).view(-1).cpu().detach().numpy()
        print('[Model output]: \n', out)

    if args.benchmark:
        # perf_counter() is monotonic, unlike time.time(), so the measured
        # interval cannot be distorted by system clock adjustments.
        start_time = time.perf_counter()
        # Input construction is deliberately inside the timed region, as in
        # the option's help text ("preprocessing + inference time").
        dummy_input = torch.ones(1, 3, cfg.INPUT.SIZE_TEST[0], cfg.INPUT.SIZE_TEST[1]).to(cfg.MODEL.DEVICE) * 255.
        with torch.no_grad():
            for _ in range(100):
                out = model(dummy_input).view(-1).cpu().detach()
        # Average seconds per forward pass over the 100 iterations.
        print("--- %s seconds ---" % ((time.perf_counter() - start_time)/100.) )

    gen_wts(args)