get_flops.py 3.75 KB
Newer Older
1
2
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
zhe chen's avatar
zhe chen committed
3
from functools import partial
4

zhe chen's avatar
zhe chen committed
5
6
import mmcv_custom  # noqa: F401,F403
import mmseg_custom  # noqa: F401,F403
7
8
9
10
11
12
13
import numpy as np
import torch
from mmcv import Config, DictAction
from mmseg.models import build_segmentor

try:
    from mmcv.cnn import get_model_complexity_info
zhe chen's avatar
zhe chen committed
14
    from mmcv.cnn.utils.flops_counter import flops_to_string, params_to_string
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
except ImportError:
    raise ImportError('Please upgrade mmcv to >0.6.2')


def parse_args():
    """Parse the command-line arguments of the FLOPs counting script.

    Returns:
        argparse.Namespace: Parsed arguments with the attributes
            ``config`` (path of the config file), ``shape`` (list of one
            or more ints), ``cfg_options`` (config overrides collected by
            ``DictAction``, ``None`` when the flag is not given) and
            ``size_divisor`` (int, defaults to 32).
    """
    # The description previously read 'Train a detector', which does not
    # describe what this script does: it reports FLOPs and parameters.
    parser = argparse.ArgumentParser(
        description='Get the FLOPs of a segmentor')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[512, 2048],
        help='input image size')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b. '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]". '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--size-divisor',
        type=int,
        default=32,
        help='Pad the input image, the minimum size that is divisible '
        'by size_divisor, -1 means do not pad the image.')
    return parser.parse_args()

zhe chen's avatar
zhe chen committed
47

48
49
50
def dcnv3_flops(n, k, c):
    """Return the analytic FLOPs estimate ``5 * n * k * c``.

    The caller in this file passes the number of spatial positions of a
    stage as ``n`` (``h * w``), the number of kernel points as ``k``
    (``3 * 3``) and the channel count of the stage as ``c``.

    Args:
        n (int | float): Number of spatial positions.
        k (int): Number of kernel points.
        c (int): Number of channels.

    Returns:
        int | float: The product ``5 * n * k * c``.
    """
    # NOTE(review): the constant 5 is presumably the per-sample operation
    # count of the DCNv3 kernel -- confirm against the operator definition.
    # The factors are multiplied in the original left-to-right order.
    per_position = 5 * n
    per_kernel = per_position * k
    return per_kernel * c

zhe chen's avatar
zhe chen committed
51

52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
def get_flops(model, input_shape):
    """Compute FLOPs and parameter count of ``model`` as display strings.

    The numbers reported by ``get_model_complexity_info`` are used as the
    base. When the class name of ``model.backbone`` contains
    ``'InternImage'``, an analytic ``dcnv3_flops`` term is added for every
    block of every backbone stage.

    Args:
        model: Model exposing a ``backbone`` attribute. For InternImage
            backbones, ``backbone.depths`` and ``backbone.channels`` are
            read as well.
        input_shape (tuple): ``(C, H, W)`` shape handed to the complexity
            counter; ``H`` and ``W`` also drive the analytic term.

    Returns:
        tuple[str, str]: FLOPs and parameter count, formatted by
            ``flops_to_string`` and ``params_to_string``.
    """
    counted_flops, params = get_model_complexity_info(
        model, input_shape, as_strings=False)

    backbone = model.backbone
    is_intern_image = 'InternImage' in type(backbone).__name__
    _, H, W = input_shape

    # NOTE(review): presumably the hook-based counter does not see the
    # DCNv3 operator, hence the manual term below -- confirm.
    dcn_flops = 0
    if is_intern_image:
        # e.g. depths == [4, 4, 18, 4]
        for stage, num_blocks in enumerate(backbone.depths):
            # Channel width doubles per stage while the stride goes
            # 4, 8, 16, ... (i.e. 4 * 2 ** stage).
            stage_channels = backbone.channels * (2 ** stage)
            stage_h = H / (4 * (2 ** stage))
            stage_w = W / (4 * (2 ** stage))
            dcn_flops += num_blocks * dcnv3_flops(
                n=stage_h * stage_w, k=3 * 3, c=stage_channels)

    total_flops = counted_flops + dcn_flops
    return flops_to_string(total_flops), params_to_string(params)

zhe chen's avatar
zhe chen committed
71

72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
if __name__ == '__main__':
    # Script entry point: build the segmentor described by the config and
    # print its FLOPs and parameter count for the requested input shape.
    args = parse_args()

    # --shape takes one value (square input) or two values (height, width).
    if len(args.shape) == 1:
        h = w = args.shape[0]
    elif len(args.shape) == 2:
        h, w = args.shape
    else:
        raise ValueError('invalid input shape')
    orig_shape = (3, h, w)

    # Round both sides up to the next multiple of the divisor; a
    # non-positive divisor leaves the shape untouched.
    divisor = args.size_divisor
    if divisor > 0:
        h = int(np.ceil(h / divisor)) * divisor
        w = int(np.ceil(w / divisor)) * divisor

    input_shape = (3, h, w)

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    model = build_segmentor(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg'))
    if torch.cuda.is_available():
        model.cuda()
    model.eval()

    # Route forward() to forward_dummy() before counting; models that do
    # not provide forward_dummy are rejected.
    if hasattr(model, 'forward_dummy'):
        model.forward = model.forward_dummy
    else:
        # Message fixed: it previously read "currently not currently".
        raise NotImplementedError(
            'FLOPs counter is currently not supported with {}'.
            format(model.__class__.__name__))

    flops, params = get_flops(model, input_shape)
    split_line = '=' * 30

    # Tell the user when padding changed the shape that was requested.
    if divisor > 0 and \
            input_shape != orig_shape:
        print(f'{split_line}\nUse size divisor set input shape '
              f'from {orig_shape} to {input_shape}\n')
    print(f'{split_line}\nInput shape: {input_shape}\n'
          f'Flops: {flops}\nParams: {params}\n{split_line}')
    print('!!!Please be cautious if you use the results in papers. '
          'You may need to check if all ops are supported and verify that the '
          'flops computation is correct.')