Unverified Commit 8008f475 authored by lizz's avatar lizz Committed by GitHub
Browse files

Support mmcv bricks in flops compuation (#715)


Signed-off-by: default avatarlizz <lizz@sensetime.com>
parent 9befc398
...@@ -30,6 +30,8 @@ import numpy as np ...@@ -30,6 +30,8 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import mmcv
def get_model_complexity_info(model, def get_model_complexity_info(model,
input_shape, input_shape,
...@@ -365,7 +367,7 @@ def start_flops_count(self): ...@@ -365,7 +367,7 @@ def start_flops_count(self):
else: else:
handle = module.register_forward_hook( handle = module.register_forward_hook(
MODULES_MAPPING[type(module)]) get_modules_mapping()[type(module)])
module.__flops_handle__ = handle module.__flops_handle__ = handle
...@@ -534,7 +536,7 @@ def add_flops_counter_variable_or_reset(module): ...@@ -534,7 +536,7 @@ def add_flops_counter_variable_or_reset(module):
def is_supported_instance(module): def is_supported_instance(module):
if type(module) in MODULES_MAPPING: if type(module) in get_modules_mapping():
return True return True
return False return False
...@@ -546,38 +548,45 @@ def remove_flops_counter_hook_function(module): ...@@ -546,38 +548,45 @@ def remove_flops_counter_hook_function(module):
del module.__flops_handle__ del module.__flops_handle__
MODULES_MAPPING = { def get_modules_mapping():
# convolutions return {
nn.Conv1d: conv_flops_counter_hook, # convolutions
nn.Conv2d: conv_flops_counter_hook, nn.Conv1d: conv_flops_counter_hook,
nn.Conv3d: conv_flops_counter_hook, nn.Conv2d: conv_flops_counter_hook,
# activations mmcv.cnn.bricks.Conv2d: conv_flops_counter_hook,
nn.ReLU: relu_flops_counter_hook, nn.Conv3d: conv_flops_counter_hook,
nn.PReLU: relu_flops_counter_hook, mmcv.cnn.bricks.Conv3d: conv_flops_counter_hook,
nn.ELU: relu_flops_counter_hook, # activations
nn.LeakyReLU: relu_flops_counter_hook, nn.ReLU: relu_flops_counter_hook,
nn.ReLU6: relu_flops_counter_hook, nn.PReLU: relu_flops_counter_hook,
# poolings nn.ELU: relu_flops_counter_hook,
nn.MaxPool1d: pool_flops_counter_hook, nn.LeakyReLU: relu_flops_counter_hook,
nn.AvgPool1d: pool_flops_counter_hook, nn.ReLU6: relu_flops_counter_hook,
nn.AvgPool2d: pool_flops_counter_hook, # poolings
nn.MaxPool2d: pool_flops_counter_hook, nn.MaxPool1d: pool_flops_counter_hook,
nn.MaxPool3d: pool_flops_counter_hook, nn.AvgPool1d: pool_flops_counter_hook,
nn.AvgPool3d: pool_flops_counter_hook, nn.AvgPool2d: pool_flops_counter_hook,
nn.AdaptiveMaxPool1d: pool_flops_counter_hook, nn.MaxPool2d: pool_flops_counter_hook,
nn.AdaptiveAvgPool1d: pool_flops_counter_hook, mmcv.cnn.bricks.MaxPool2d: pool_flops_counter_hook,
nn.AdaptiveMaxPool2d: pool_flops_counter_hook, nn.MaxPool3d: pool_flops_counter_hook,
nn.AdaptiveAvgPool2d: pool_flops_counter_hook, mmcv.cnn.bricks.MaxPool3d: pool_flops_counter_hook,
nn.AdaptiveMaxPool3d: pool_flops_counter_hook, nn.AvgPool3d: pool_flops_counter_hook,
nn.AdaptiveAvgPool3d: pool_flops_counter_hook, nn.AdaptiveMaxPool1d: pool_flops_counter_hook,
# BNs nn.AdaptiveAvgPool1d: pool_flops_counter_hook,
nn.BatchNorm1d: bn_flops_counter_hook, nn.AdaptiveMaxPool2d: pool_flops_counter_hook,
nn.BatchNorm2d: bn_flops_counter_hook, nn.AdaptiveAvgPool2d: pool_flops_counter_hook,
nn.BatchNorm3d: bn_flops_counter_hook, nn.AdaptiveMaxPool3d: pool_flops_counter_hook,
# FC nn.AdaptiveAvgPool3d: pool_flops_counter_hook,
nn.Linear: linear_flops_counter_hook, # BNs
# Upscale nn.BatchNorm1d: bn_flops_counter_hook,
nn.Upsample: upsample_flops_counter_hook, nn.BatchNorm2d: bn_flops_counter_hook,
# Deconvolution nn.BatchNorm3d: bn_flops_counter_hook,
nn.ConvTranspose2d: deconv_flops_counter_hook, # FC
} nn.Linear: linear_flops_counter_hook,
mmcv.cnn.bricks.Linear: linear_flops_counter_hook,
# Upscale
nn.Upsample: upsample_flops_counter_hook,
# Deconvolution
nn.ConvTranspose2d: deconv_flops_counter_hook,
mmcv.cnn.bricks.ConvTranspose2d: deconv_flops_counter_hook,
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment