"detection/ops_dcnv3/modules/__init__.py" did not exist on "5ba0b54748c8635c1d626d0bf409f8642896a4c3"
get_flops.py 3.74 KB
Newer Older
1
2
3
# Copyright (c) OpenMMLab. All rights reserved.
import argparse

zhe chen's avatar
zhe chen committed
4
5
import mmcv_custom  # noqa: F401,F403
import mmdet_custom  # noqa: F401,F403
6
7
8
9
10
11
12
import numpy as np
import torch
from mmcv import Config, DictAction
from mmdet.models import build_detector

try:
    from mmcv.cnn import get_model_complexity_info
zhe chen's avatar
zhe chen committed
13
    from mmcv.cnn.utils.flops_counter import flops_to_string, params_to_string
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
except ImportError:
    raise ImportError('Please upgrade mmcv to >0.6.2')


def parse_args():
    """Parse the command-line arguments of the FLOPs counting script.

    Returns:
        argparse.Namespace: Parsed arguments with the attributes
            ``config`` (config file path), ``shape`` (list of ints,
            default ``[800, 1280]``), ``cfg_options`` (value built by
            ``DictAction``, ``None`` when the flag is not given) and
            ``size_divisor`` (int, default ``32``).
    """
    # This script reports FLOPs/params; it does not train anything, so the
    # description shown by ``--help`` must say so.
    parser = argparse.ArgumentParser(
        description='Get the FLOPs of a detector')
    parser.add_argument('config', help='train config file path')
    # One value means a square input, two values mean (height, width); the
    # length is validated by the caller, not here.
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[800, 1280],
        help='input image size')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
             'in xxx=yyy format will be merged into config file. If the value to '
             'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
             'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
             'Note that the quotation marks are necessary and that no white space '
             'is allowed.')
    parser.add_argument(
        '--size-divisor',
        type=int,
        default=32,
        help='Pad the input image, the minimum size that is divisible '
             'by size_divisor, -1 means do not pad the image.')
    args = parser.parse_args()
    return args


def dcnv3_flops(n, k, c):
    """Return the analytic FLOPs term ``5 * n * k * c``.

    Args:
        n: Number of spatial positions; ``get_flops`` passes ``h * w``.
        k: Kernel size term; ``get_flops`` passes ``3 * 3``.
        c: Channel count; ``get_flops`` passes the stage channels.

    Returns:
        The product ``5 * n * k * c`` (a float when any factor is a float).

    NOTE(review): the origin of the constant 5 is not shown in this file;
    confirm against the DCNv3 operator definition before changing it.
    """
    # Multiply left to right, exactly as ``5 * n * k * c`` is evaluated, so
    # float inputs round the same way.
    scaled_positions = 5 * n
    return scaled_positions * k * c


def get_flops(model, input_shape):
    """Count FLOPs and parameters of ``model`` for one input shape.

    The base numbers come from ``get_model_complexity_info``. When the class
    name of ``model.backbone`` contains ``'InternImage'``, an analytic term
    from ``dcnv3_flops`` is added for every block of every stage.
    NOTE(review): presumably the generic counter does not cover the DCNv3
    op, which is why it is added by hand -- confirm.

    Args:
        model: Model with a ``backbone`` attribute. For InternImage
            backbones, ``backbone.depths`` (blocks per stage) and
            ``backbone.channels`` (first-stage channels) are read.
        input_shape (tuple): ``(C, H, W)``; only ``H`` and ``W`` are used
            for the extra term.

    Returns:
        tuple: ``(flops, params)`` as formatted by ``flops_to_string`` and
        ``params_to_string``.
    """
    flops, params = get_model_complexity_info(
        model, input_shape, as_strings=False)

    backbone = model.backbone
    backbone_name = type(backbone).__name__
    _, H, W = input_shape

    dcnv3_extra = 0
    if 'InternImage' in backbone_name:
        depths = backbone.depths  # e.g. [4, 4, 18, 4]
        for idx, depth in enumerate(depths):
            # Stage ``idx`` is evaluated at 1 / (4 * 2**idx) of the input
            # resolution with 2**idx times the base channel count.
            channels = backbone.channels * (2 ** idx)
            h = H / (4 * (2 ** idx))
            w = W / (4 * (2 ** idx))
            dcnv3_extra += depth * dcnv3_flops(
                n=h * w, k=3 * 3, c=channels)

    flops = flops + dcnv3_extra
    return flops_to_string(flops), params_to_string(params)


if __name__ == '__main__':
    # Script entry point: build the detector described by the config and
    # print its FLOPs and parameter count for the requested input shape.
    args = parse_args()

    # --shape takes one value (square input) or two values (height, width).
    if len(args.shape) == 1:
        h = w = args.shape[0]
    elif len(args.shape) == 2:
        h, w = args.shape
    else:
        raise ValueError('invalid input shape')
    orig_shape = (3, h, w)

    # Round each side up to the next multiple of the divisor; a non-positive
    # divisor (e.g. -1) leaves the shape untouched.
    divisor = args.size_divisor
    if divisor > 0:
        h = int(np.ceil(h / divisor)) * divisor
        w = int(np.ceil(w / divisor)) * divisor

    input_shape = (3, h, w)

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    model = build_detector(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg'))

    if torch.cuda.is_available():
        model.cuda()
    model.eval()
    # The complexity counter calls ``model.forward`` with a bare tensor, so
    # route it through ``forward_dummy`` when the model provides one.
    if hasattr(model, 'forward_dummy'):
        model.forward = model.forward_dummy
    else:
        raise NotImplementedError(
            'FLOPs counter is currently not supported with {}'.format(
                model.__class__.__name__))

    flops, params = get_flops(model, input_shape)
    split_line = '=' * 30

    # Tell the user when padding changed the shape that was measured.
    if divisor > 0 and \
            input_shape != orig_shape:
        print(f'{split_line}\nUse size divisor set input shape '
              f'from {orig_shape} to {input_shape}\n')
    print(f'{split_line}\nInput shape: {input_shape}\n'
          f'Flops: {flops}\nParams: {params}\n{split_line}')
    print('!!!Please be cautious if you use the results in papers. '
          'You may need to check if all ops are supported and verify that the '
          'flops computation is correct.')