Unverified Commit 8008f475 authored by lizz's avatar lizz Committed by GitHub
Browse files

Support mmcv bricks in flops compuation (#715)


Signed-off-by: default avatarlizz <lizz@sensetime.com>
parent 9befc398
......@@ -30,6 +30,8 @@ import numpy as np
import torch
import torch.nn as nn
import mmcv
def get_model_complexity_info(model,
input_shape,
......@@ -365,7 +367,7 @@ def start_flops_count(self):
else:
handle = module.register_forward_hook(
MODULES_MAPPING[type(module)])
get_modules_mapping()[type(module)])
module.__flops_handle__ = handle
......@@ -534,7 +536,7 @@ def add_flops_counter_variable_or_reset(module):
def is_supported_instance(module):
if type(module) in MODULES_MAPPING:
if type(module) in get_modules_mapping():
return True
return False
......@@ -546,38 +548,45 @@ def remove_flops_counter_hook_function(module):
del module.__flops_handle__
MODULES_MAPPING = {
# convolutions
nn.Conv1d: conv_flops_counter_hook,
nn.Conv2d: conv_flops_counter_hook,
nn.Conv3d: conv_flops_counter_hook,
# activations
nn.ReLU: relu_flops_counter_hook,
nn.PReLU: relu_flops_counter_hook,
nn.ELU: relu_flops_counter_hook,
nn.LeakyReLU: relu_flops_counter_hook,
nn.ReLU6: relu_flops_counter_hook,
# poolings
nn.MaxPool1d: pool_flops_counter_hook,
nn.AvgPool1d: pool_flops_counter_hook,
nn.AvgPool2d: pool_flops_counter_hook,
nn.MaxPool2d: pool_flops_counter_hook,
nn.MaxPool3d: pool_flops_counter_hook,
nn.AvgPool3d: pool_flops_counter_hook,
nn.AdaptiveMaxPool1d: pool_flops_counter_hook,
nn.AdaptiveAvgPool1d: pool_flops_counter_hook,
nn.AdaptiveMaxPool2d: pool_flops_counter_hook,
nn.AdaptiveAvgPool2d: pool_flops_counter_hook,
nn.AdaptiveMaxPool3d: pool_flops_counter_hook,
nn.AdaptiveAvgPool3d: pool_flops_counter_hook,
# BNs
nn.BatchNorm1d: bn_flops_counter_hook,
nn.BatchNorm2d: bn_flops_counter_hook,
nn.BatchNorm3d: bn_flops_counter_hook,
# FC
nn.Linear: linear_flops_counter_hook,
# Upscale
nn.Upsample: upsample_flops_counter_hook,
# Deconvolution
nn.ConvTranspose2d: deconv_flops_counter_hook,
}
def get_modules_mapping():
return {
# convolutions
nn.Conv1d: conv_flops_counter_hook,
nn.Conv2d: conv_flops_counter_hook,
mmcv.cnn.bricks.Conv2d: conv_flops_counter_hook,
nn.Conv3d: conv_flops_counter_hook,
mmcv.cnn.bricks.Conv3d: conv_flops_counter_hook,
# activations
nn.ReLU: relu_flops_counter_hook,
nn.PReLU: relu_flops_counter_hook,
nn.ELU: relu_flops_counter_hook,
nn.LeakyReLU: relu_flops_counter_hook,
nn.ReLU6: relu_flops_counter_hook,
# poolings
nn.MaxPool1d: pool_flops_counter_hook,
nn.AvgPool1d: pool_flops_counter_hook,
nn.AvgPool2d: pool_flops_counter_hook,
nn.MaxPool2d: pool_flops_counter_hook,
mmcv.cnn.bricks.MaxPool2d: pool_flops_counter_hook,
nn.MaxPool3d: pool_flops_counter_hook,
mmcv.cnn.bricks.MaxPool3d: pool_flops_counter_hook,
nn.AvgPool3d: pool_flops_counter_hook,
nn.AdaptiveMaxPool1d: pool_flops_counter_hook,
nn.AdaptiveAvgPool1d: pool_flops_counter_hook,
nn.AdaptiveMaxPool2d: pool_flops_counter_hook,
nn.AdaptiveAvgPool2d: pool_flops_counter_hook,
nn.AdaptiveMaxPool3d: pool_flops_counter_hook,
nn.AdaptiveAvgPool3d: pool_flops_counter_hook,
# BNs
nn.BatchNorm1d: bn_flops_counter_hook,
nn.BatchNorm2d: bn_flops_counter_hook,
nn.BatchNorm3d: bn_flops_counter_hook,
# FC
nn.Linear: linear_flops_counter_hook,
mmcv.cnn.bricks.Linear: linear_flops_counter_hook,
# Upscale
nn.Upsample: upsample_flops_counter_hook,
# Deconvolution
nn.ConvTranspose2d: deconv_flops_counter_hook,
mmcv.cnn.bricks.ConvTranspose2d: deconv_flops_counter_hook,
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment