test_flops.py 1.72 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## Email: zhanghang0704@gmail.com
## Copyright (c) 2020
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree 
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

import argparse
import torch

from thop import profile, clever_format

import encoding

def get_args():
    # data settings
    parser = argparse.ArgumentParser(description='Deep Encoding')
    parser.add_argument('--crop-size', type=int, default=224,
                        help='crop image size')
    # model params 
    parser.add_argument('--model', type=str, default='densenet',
                        help='network model type (default: densenet)')
    parser.add_argument('--rectify', action='store_true', 
                        default=False, help='rectify convolution')
    parser.add_argument('--rectify-avg', action='store_true', 
                        default=False, help='rectify convolution')
    # checking point
    parser = parser

    args = parser.parse_args()
    return args

def main():
    args = get_args()

    model_kwargs = {}
    if args.rectify:
        model_kwargs['rectified_conv'] = True
        model_kwargs['rectify_avg'] = args.rectify_avg

    model = encoding.models.get_model(args.model, **model_kwargs)
    print(model)

    dummy_images = torch.rand(1, 3, args.crop_size, args.crop_size)

    #count_ops(model, dummy_images, verbose=False)
    macs, params = profile(model, inputs=(dummy_images, ))
    macs, params = clever_format([macs, params], "%.3f") 

    print(f"macs: {macs}, params: {params}")

if __name__ == '__main__':
    main()