# Copyright (c) OpenMMLab. All rights reserved.
import argparse

import mmcv_custom  # noqa: F401,F403
import mmseg_custom  # noqa: F401,F403
import numpy as np
import torch
from mmcv import Config, DictAction
from mmseg.models import build_segmentor

try:
    from mmcv.cnn import get_model_complexity_info
    from mmcv.cnn.utils.flops_counter import flops_to_string, params_to_string
except ImportError:
    raise ImportError('Please upgrade mmcv to >0.6.2')


def parse_args():
    """Parse the command-line arguments of the FLOPs/params counter.

    Returns:
        argparse.Namespace with the attributes ``config`` (config file path),
        ``shape`` (list of one or two ints, default ``[512, 2048]``),
        ``cfg_options`` (dict of config overrides, or None) and
        ``size_divisor`` (int, default 32; -1 disables padding).
    """
    # The original description/help ('Train a detector', 'train config file
    # path') were copy-paste leftovers; this script only measures a segmentor.
    parser = argparse.ArgumentParser(
        description='Get the FLOPs and params of a segmentor')
    parser.add_argument('config', help='config file path')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[512, 2048],
        help='input image size')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b. '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]". '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--size-divisor',
        type=int,
        default=32,
        help='Pad the input image, the minimum size that is divisible '
        'by size_divisor, -1 means do not pad the image.')
    args = parser.parse_args()
    return args

def dcnv3_flops(n, k, c):
    """Return the estimated cost of one DCNv3 operator as ``5 * n * k * c``.

    Args:
        n: number of spatial positions (the caller in this file passes h * w).
        k: number of sampling points per position (the caller passes 3 * 3).
        c: number of channels.

    Returns:
        The product ``5 * n * k * c``. The factor 5 is hard-coded here; its
        derivation is not shown in this file.
    """
    # Same left-to-right evaluation as ``5 * n * k * c``.
    cost_per_position = 5 * n
    return cost_per_position * k * c

def get_flops(model, input_shape):
    """Count FLOPs and params of ``model`` and return them as strings.

    The base numbers come from mmcv's ``get_model_complexity_info``. If the
    backbone's class name contains ``'InternImage'``, an analytic DCNv3 cost
    (via ``dcnv3_flops``) is added for every stage. Presumably this is because
    the generic counter does not see the custom DCNv3 op.

    Args:
        model: segmentor exposing a ``backbone`` attribute.
        input_shape: ``(C, H, W)`` tuple passed to the complexity counter.

    Returns:
        ``(flops_str, params_str)`` formatted by ``flops_to_string`` and
        ``params_to_string``.
    """
    base_flops, params = get_model_complexity_info(
        model, input_shape, as_strings=False)

    backbone = model.backbone
    _, height, width = input_shape

    dcn_extra = 0
    if 'InternImage' in type(backbone).__name__:
        # NOTE(review): assumes stage i runs at stride 4 * 2**i with
        # channels * 2**i channels and one DCNv3 per block. This uses true
        # division, so H/W that are not multiples of the stride give
        # fractional sizes. Confirm against the InternImage backbone.
        for stage, depth in enumerate(backbone.depths):
            stride = 4 * 2 ** stage
            stage_channels = backbone.channels * 2 ** stage
            positions = (height / stride) * (width / stride)
            dcn_extra += depth * dcnv3_flops(
                n=positions, k=3 * 3, c=stage_channels)

    return flops_to_string(base_flops + dcn_extra), params_to_string(params)

if __name__ == '__main__':
    # Script entry: build the segmentor from a config, then report its
    # FLOPs/params for a dummy input of the requested (optionally padded) size.

    args = parse_args()

    # --shape accepts either a single value (square input) or H W.
    if len(args.shape) == 1:
        h = w = args.shape[0]
    elif len(args.shape) == 2:
        h, w = args.shape
    else:
        raise ValueError('invalid input shape')
    orig_shape = (3, h, w)
    divisor = args.size_divisor
    if divisor > 0:
        # Pad each side up to the smallest multiple of ``divisor``, as
        # described by the --size-divisor help text.
        h = int(np.ceil(h / divisor)) * divisor
        w = int(np.ceil(w / divisor)) * divisor

    input_shape = (3, h, w)

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    model = build_segmentor(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg'))
    if torch.cuda.is_available():
        model.cuda()
    model.eval()
    # The complexity counter calls model(batch); route that call to the
    # segmentor's single-tensor dummy forward.
    if hasattr(model, 'forward_dummy'):
        model.forward = model.forward_dummy
    else:
        raise NotImplementedError(
            'FLOPs counter is currently not supported with {}'.format(
                model.__class__.__name__))

    flops, params = get_flops(model, input_shape)
    split_line = '=' * 30

    if divisor > 0 and input_shape != orig_shape:
        print(f'{split_line}\nUse size divisor set input shape '
              f'from {orig_shape} to {input_shape}\n')
    print(f'{split_line}\nInput shape: {input_shape}\n'
          f'Flops: {flops}\nParams: {params}\n{split_line}')
    print('!!!Please be cautious if you use the results in papers. '
          'You may need to check if all ops are supported and verify that the '
          'flops computation is correct.')