"vscode:/vscode.git/clone" did not exist on "04eaae25e00bb9fe5b0e25bd18fc7afd0d0351a4"
Unverified Commit ed9d42a2 authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

add GN support for flops computation (#1850)

* add GN support for flops computation

* remove useless lines

* modify the flops computation for gn
parent 9729ca54
...@@ -3,6 +3,6 @@ line_length = 79 ...@@ -3,6 +3,6 @@ line_length = 79
multi_line_output = 0 multi_line_output = 0
known_standard_library = setuptools known_standard_library = setuptools
known_first_party = mmdet known_first_party = mmdet
known_third_party = Cython,albumentations,cv2,imagecorruptions,matplotlib,mmcv,numpy,pycocotools,robustness_eval,roi_align,roi_pool,seaborn,six,terminaltables,torch,torchvision known_third_party = Cython,albumentations,asynctest,cv2,imagecorruptions,matplotlib,mmcv,numpy,pycocotools,robustness_eval,roi_align,roi_pool,seaborn,six,terminaltables,torch,torchvision
no_lines_before = STDLIB,LOCALFOLDER no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY default_section = THIRDPARTY
...@@ -33,19 +33,6 @@ from torch.nn.modules.conv import _ConvNd, _ConvTransposeMixin ...@@ -33,19 +33,6 @@ from torch.nn.modules.conv import _ConvNd, _ConvTransposeMixin
from torch.nn.modules.pooling import (_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, from torch.nn.modules.pooling import (_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd,
_AvgPoolNd, _MaxPoolNd) _AvgPoolNd, _MaxPoolNd)
CONV_TYPES = (_ConvNd, )
DECONV_TYPES = (_ConvTransposeMixin, )
LINEAR_TYPES = (nn.Linear, )
POOLING_TYPES = (_AvgPoolNd, _MaxPoolNd, _AdaptiveAvgPoolNd,
_AdaptiveMaxPoolNd)
RELU_TYPES = (nn.ReLU, nn.PReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6)
BN_TYPES = (_BatchNorm, )
UPSAMPLE_TYPES = (nn.Upsample, )
SUPPORTED_TYPES = (
CONV_TYPES + DECONV_TYPES + LINEAR_TYPES + POOLING_TYPES + RELU_TYPES +
BN_TYPES + UPSAMPLE_TYPES)
def get_model_complexity_info(model, def get_model_complexity_info(model,
input_res, input_res,
...@@ -249,10 +236,10 @@ def remove_flops_mask(module): ...@@ -249,10 +236,10 @@ def remove_flops_mask(module):
def is_supported_instance(module):
    """Check whether FLOPs counting is supported for ``module``.

    A module is supported when its type is (a subclass of) any key in the
    module-level ``hook_mapping`` registry, mirroring the subclass lookup
    used when the forward hooks are registered.

    Args:
        module (nn.Module): The module to check.

    Returns:
        bool: True if a FLOPs counter hook exists for this module type.
    """
    # isinstance(module, <tuple of classes>) is the idiomatic equivalent of
    # looping with issubclass(type(module), cls).
    return isinstance(module, tuple(hook_mapping))
def empty_flops_counter_hook(module, input, output): def empty_flops_counter_hook(module, input, output):
...@@ -285,7 +272,6 @@ def pool_flops_counter_hook(module, input, output): ...@@ -285,7 +272,6 @@ def pool_flops_counter_hook(module, input, output):
def bn_flops_counter_hook(module, input, output): def bn_flops_counter_hook(module, input, output):
module.affine
input = input[0] input = input[0]
batch_flops = np.prod(input.shape) batch_flops = np.prod(input.shape)
...@@ -294,6 +280,17 @@ def bn_flops_counter_hook(module, input, output): ...@@ -294,6 +280,17 @@ def bn_flops_counter_hook(module, input, output):
module.__flops__ += int(batch_flops) module.__flops__ += int(batch_flops)
def gn_flops_counter_hook(module, input, output):
    """Accumulate the estimated FLOPs of a GroupNorm forward pass.

    There is no precise FLOPs estimate for computing the mean and
    variance; we approximate that part as 2 * elems (half muladds for the
    means, half for the variances), plus one pass of elems for the
    normalization itself, and one more for the affine scale/shift when
    enabled.

    Args:
        module (nn.GroupNorm): The hooked normalization module.
        input (tuple): Forward inputs; only ``input[0]`` is inspected.
        output: Forward output (unused).
    """
    num_elems = np.prod(input[0].shape)
    # 2x for mean/var estimation + 1x for the normalization pass.
    total_flops = 3 * num_elems
    if module.affine:
        # Learnable per-channel scale and shift add one more pass.
        total_flops += num_elems
    module.__flops__ += int(total_flops)
def deconv_flops_counter_hook(conv_module, input, output): def deconv_flops_counter_hook(conv_module, input, output):
# Can have multiple inputs, getting the first one # Can have multiple inputs, getting the first one
input = input[0] input = input[0]
...@@ -359,6 +356,32 @@ def conv_flops_counter_hook(conv_module, input, output): ...@@ -359,6 +356,32 @@ def conv_flops_counter_hook(conv_module, input, output):
conv_module.__flops__ += int(overall_flops) conv_module.__flops__ += int(overall_flops)
# Registry mapping a module (base) class to the forward hook that counts
# its FLOPs. Lookup elsewhere in this file is done via a subclass check,
# so any subclass of a key (e.g. the concrete BatchNorm1d/2d/3d variants
# under _BatchNorm) is dispatched to that key's hook.
hook_mapping = {
    # conv
    _ConvNd: conv_flops_counter_hook,
    # deconv
    _ConvTransposeMixin: deconv_flops_counter_hook,
    # fc
    nn.Linear: linear_flops_counter_hook,
    # pooling
    _AvgPoolNd: pool_flops_counter_hook,
    _MaxPoolNd: pool_flops_counter_hook,
    _AdaptiveAvgPoolNd: pool_flops_counter_hook,
    _AdaptiveMaxPoolNd: pool_flops_counter_hook,
    # activation
    nn.ReLU: relu_flops_counter_hook,
    nn.PReLU: relu_flops_counter_hook,
    nn.ELU: relu_flops_counter_hook,
    nn.LeakyReLU: relu_flops_counter_hook,
    nn.ReLU6: relu_flops_counter_hook,
    # normalization
    _BatchNorm: bn_flops_counter_hook,
    nn.GroupNorm: gn_flops_counter_hook,
    # upsample
    nn.Upsample: upsample_flops_counter_hook,
}
def batch_counter_hook(module, input, output): def batch_counter_hook(module, input, output):
batch_size = 1 batch_size = 1
if len(input) > 0: if len(input) > 0:
...@@ -372,7 +395,6 @@ def batch_counter_hook(module, input, output): ...@@ -372,7 +395,6 @@ def batch_counter_hook(module, input, output):
def add_batch_counter_variables_or_reset(module): def add_batch_counter_variables_or_reset(module):
module.__batch_counter__ = 0 module.__batch_counter__ = 0
...@@ -400,22 +422,11 @@ def add_flops_counter_hook_function(module): ...@@ -400,22 +422,11 @@ def add_flops_counter_hook_function(module):
if hasattr(module, '__flops_handle__'): if hasattr(module, '__flops_handle__'):
return return
if isinstance(module, CONV_TYPES): for mod_type, counter_hook in hook_mapping.items():
handle = module.register_forward_hook(conv_flops_counter_hook) if issubclass(type(module), mod_type):
elif isinstance(module, RELU_TYPES): handle = module.register_forward_hook(counter_hook)
handle = module.register_forward_hook(relu_flops_counter_hook) break
elif isinstance(module, LINEAR_TYPES):
handle = module.register_forward_hook(linear_flops_counter_hook)
elif isinstance(module, POOLING_TYPES):
handle = module.register_forward_hook(pool_flops_counter_hook)
elif isinstance(module, BN_TYPES):
handle = module.register_forward_hook(bn_flops_counter_hook)
elif isinstance(module, UPSAMPLE_TYPES):
handle = module.register_forward_hook(upsample_flops_counter_hook)
elif isinstance(module, DECONV_TYPES):
handle = module.register_forward_hook(deconv_flops_counter_hook)
else:
handle = module.register_forward_hook(empty_flops_counter_hook)
module.__flops_handle__ = handle module.__flops_handle__ = handle
......
...@@ -46,6 +46,9 @@ def main(): ...@@ -46,6 +46,9 @@ def main():
split_line = '=' * 30 split_line = '=' * 30
print('{0}\nInput shape: {1}\nFlops: {2}\nParams: {3}\n{0}'.format( print('{0}\nInput shape: {1}\nFlops: {2}\nParams: {3}\n{0}'.format(
split_line, input_shape, flops, params)) split_line, input_shape, flops, params))
print('!!!Please be cautious if you use the results in papers. '
'You may need to check if all ops are supported and verify that the '
'flops computation is correct.')
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment