Unverified Commit b6233e52 authored by colorjam's avatar colorjam Committed by GitHub
Browse files

Refactor flops counter (#3048)

parent 291bbbba
......@@ -121,14 +121,28 @@ fixed_mask = fix_mask_conflict('./resnet18_mask', net, data)
```
## Model FLOPs/Parameters Counter
We provide a model counter for calculating the model FLOPs and parameters. This counter supports calculating FLOPs/parameters of a normal model without masks; it can also calculate FLOPs/parameters of a model with mask wrappers, which helps users easily check model complexity during model compression on NNI. Note that, for structured pruning, we only identify the remaining filters according to their masks, without taking the pruned input channels into consideration, so the calculated FLOPs will be larger than the real number (i.e., the number calculated after Model Speedup).
We provide a model counter for calculating the model FLOPs and parameters. This counter supports calculating FLOPs/parameters of a normal model without masks; it can also calculate FLOPs/parameters of a model with mask wrappers, which helps users easily check model complexity during model compression on NNI. Note that, for structured pruning, we only identify the remaining filters according to their masks, without taking the pruned input channels into consideration, so the calculated FLOPs will be larger than the real number (i.e., the number calculated after Model Speedup).
We support two modes to collect information of modules. The first mode is `default`, which only collects the information of convolution and linear layers. The second mode is `full`, which also collects the information of other operations. Users can easily use our collected `results` for further analysis.
### Usage
```
from nni.compression.pytorch.utils.counter import count_flops_params
# Given input size (1, 1, 28, 28)
flops, params = count_flops_params(model, (1, 1, 28, 28))
# Given input size (1, 1, 28, 28)
flops, params, results = count_flops_params(model, (1, 1, 28, 28))
# Given input tensor with size (1, 1, 28, 28) and switch to full mode
x = torch.randn(1, 1, 28, 28)
flops, params, results = count_flops_params(model, (x,), mode='full') # tuple of tensor as input
# Format output size to M (i.e., 10^6)
print(f'FLOPs: {flops/1e6:.3f}M, Params: {params/1e6:.3f}M')
print(results)
{
'conv': {'flops': [60], 'params': [20], 'weight_size': [(5, 3, 1, 1)], 'input_size': [(1, 3, 2, 2)], 'output_size': [(1, 5, 2, 2)], 'module_type': ['Conv2d']},
'conv2': {'flops': [100], 'params': [30], 'weight_size': [(5, 5, 1, 1)], 'input_size': [(1, 5, 2, 2)], 'output_size': [(1, 5, 2, 2)], 'module_type': ['Conv2d']}
}
```
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import functools
from collections import Counter
from prettytable import PrettyTable
import torch
import torch.nn as nn
from nni.compression.pytorch.compressor import PrunerModuleWrapper
try:
from thop import profile
except Exception as e:
print('thop is not found, please install the python package: thop')
raise
__all__ = ['count_flops_params']
def count_flops_params(model: nn.Module, input_size, custom_ops=None, verbose=True):
"""
Count FLOPs and Params of the given model.
This function would identify the mask on the module
and take the pruned shape into consideration.
Note that, for sturctured pruning, we only identify
the remained filters according to its mask, which
not taking the pruned input channels into consideration,
so the calculated FLOPs will be larger than real number.
Parameters
---------
model : nn.Module
target model.
input_size: list, tuple
the input shape of data
custom_ops: dict
a mapping of (module: custom operation)
the custom operation will overwrite the default operation.
for reference, please see ``custom_mask_ops``.
def _get_params(m):
return sum([p.numel() for p in m.parameters()])
Returns
-------
flops: float
total flops of the model
params:
total params of the model
"""
assert input_size is not None
class ModelProfiler:
device = next(model.parameters()).device
inputs = torch.randn(input_size).to(device)
def __init__(self, custom_ops=None, mode='default'):
"""
ModelProfiler is used to share state to hooks.
hook_module_list = []
if custom_ops is None:
custom_ops = {}
custom_mask_ops.update(custom_ops)
prev_m = None
for m in model.modules():
weight_mask = None
m_type = type(m)
if m_type in custom_mask_ops:
if isinstance(prev_m, PrunerModuleWrapper):
weight_mask = prev_m.weight_mask
m.register_buffer('weight_mask', weight_mask)
hook_module_list.append(m)
prev_m = m
Parameters
----------
custom_ops: dict
a mapping of (module -> torch.nn.Module : custom operation)
the custom operation is a callback funtion to calculate
the module flops, parameters and the weight shape, it will overwrite the default operation.
for reference, please see ``self.ops``.
mode:
the mode of how to collect information. If the mode is set to `default`,
only the information of convolution and linear will be collected.
If the mode is set to `full`, other operations will also be collected.
"""
self.ops = {
nn.Conv1d: self._count_convNd,
nn.Conv2d: self._count_convNd,
nn.Conv3d: self._count_convNd,
nn.Linear: self._count_linear
}
self._count_bias = False
if mode == 'full':
self.ops.update({
nn.ConvTranspose1d: self._count_convNd,
nn.ConvTranspose2d: self._count_convNd,
nn.ConvTranspose3d: self._count_convNd,
nn.BatchNorm1d: self._count_bn,
nn.BatchNorm2d: self._count_bn,
nn.BatchNorm3d: self._count_bn,
nn.LeakyReLU: self._count_relu,
nn.AvgPool1d: self._count_avgpool,
nn.AvgPool2d: self._count_avgpool,
nn.AvgPool3d: self._count_avgpool,
nn.AdaptiveAvgPool1d: self._count_adap_avgpool,
nn.AdaptiveAvgPool2d: self._count_adap_avgpool,
nn.AdaptiveAvgPool3d: self._count_adap_avgpool,
nn.Upsample: self._count_upsample,
nn.UpsamplingBilinear2d: self._count_upsample,
nn.UpsamplingNearest2d: self._count_upsample
})
self._count_bias = True
flops, params = profile(model, inputs=(inputs, ), custom_ops=custom_mask_ops, verbose=verbose)
if custom_ops is not None:
self.ops.update(custom_ops)
self.mode = mode
self.results = []
for m in hook_module_list:
m._buffers.pop("weight_mask")
# Remove registerd buffer on the model, and fixed following issue:
# https://github.com/Lyken17/pytorch-OpCounter/issues/96
for m in model.modules():
if 'total_ops' in m._buffers:
m._buffers.pop("total_ops")
if 'total_params' in m._buffers:
m._buffers.pop("total_params")
def _push_result(self, result):
self.results.append(result)
return flops, params
def _get_result(self, m, flops):
# assume weight is called `weight`, otherwise it's not applicable
# if user customize the operation, the callback function should
# return the dict result, inluding calculated flops, params and weight_shape.
def count_convNd_mask(m, x, y):
"""
The forward hook to count FLOPs and Parameters of convolution operation.
Parameters
----------
m : torch.nn.Module
convolution module to calculate the FLOPs and Parameters
x : torch.Tensor
input data
y : torch.Tensor
output data
"""
output_channel = y.size()[1]
output_size = torch.zeros(y.size()[2:]).numel()
kernel_size = torch.zeros(m.weight.size()[2:]).numel()
result = {
'flops': flops,
'params': _get_params(m),
'weight_shape': tuple(m.weight.size()) if hasattr(m, 'weight') else 0,
}
return result
def _count_convNd(self, m, x, y):
cin = m.in_channels
kernel_ops = m.weight.size()[2] * m.weight.size()[3]
output_size = torch.zeros(y.size()[2:]).numel()
cout = y.size()[1]
if hasattr(m, 'weight_mask'):
cout = m.weight_mask.sum() // (cin * kernel_ops)
total_ops = cout * output_size * kernel_ops * cin // m.groups # cout x oW x oH
if self._count_bias:
bias_flops = 1 if m.bias is not None else 0
total_ops += cout * output_size * bias_flops
return self._get_result(m, total_ops)
def _count_linear(self, m, x, y):
out_features = m.out_features
if hasattr(m, 'weight_mask'):
out_features = m.weight_mask.sum() // m.in_features
total_ops = out_features * m.in_features
if self._count_bias:
bias_flops = 1 if m.bias is not None else 0
total_ops += out_features * bias_flops
return self._get_result(m, total_ops)
def _count_bn(self, m, x, y):
total_ops = 2 * x[0].numel()
return self._get_result(m, total_ops)
def _count_relu(self, m, x, y):
total_ops = x[0].numel()
return self._get_result(m, total_ops)
bias_flops = 1 if m.bias is not None else 0
def _count_avgpool(self, m, x, y):
total_ops = y.numel()
return self._get_result(m, total_ops)
if m.weight_mask is not None:
output_channel = m.weight_mask.sum() // (m.in_channels * kernel_size)
def _count_adap_avgpool(self, m, x, y):
kernel = torch.Tensor([*(x[0].shape[2:])]) // torch.Tensor(list((m.output_size,))).squeeze()
total_add = int(torch.prod(kernel))
total_div = 1
kernel_ops = total_add + total_div
num_elements = y.numel()
total_ops = kernel_ops * num_elements
total_ops = output_channel * output_size * (m.in_channels // m.groups * kernel_size + bias_flops)
return self._get_result(m, total_ops)
m.total_ops += torch.DoubleTensor([int(total_ops)])
def _count_upsample(self, m, x, y):
if m.mode == 'linear':
total_ops = y.nelement() * 5 # 2 muls + 3 add
elif m.mode == 'bilinear':
# https://en.wikipedia.org/wiki/Bilinear_interpolation
total_ops = y.nelement() * 11 # 6 muls + 5 adds
elif m.mode == 'bicubic':
# https://en.wikipedia.org/wiki/Bicubic_interpolation
# Product matrix [4x4] x [4x4] x [4x4]
ops_solve_A = 224 # 128 muls + 96 adds
ops_solve_p = 35 # 16 muls + 12 adds + 4 muls + 3 adds
total_ops = y.nelement() * (ops_solve_A + ops_solve_p)
elif m.mode == 'trilinear':
# https://en.wikipedia.org/wiki/Trilinear_interpolation
# can viewed as 2 bilinear + 1 linear
total_ops = y.nelement() * (13 * 2 + 5)
else:
total_ops = 0
return self._get_result(m, total_ops)
def count_linear_mask(m, x, y):
# Generic forward hook attached to every profiled module; ``name`` is the
# module's qualified name from ``named_modules`` (bound via functools.partial
# by the caller). Dispatches to the per-type counter in ``self.ops`` and
# records the merged result.
def count_module(self, m, x, y, name):
# assume x is tuple of single tensor
result = self.ops[type(m)](m, x, y)
# merge the counter's result (flops/params/weight_shape) with call metadata
total_result = {
'name': name,
'input_size': tuple(x[0].size()),
'output_size': tuple(y.size()),
'module_type': type(m).__name__,
**result
}
# one entry per forward call: a module invoked twice yields two entries
self._push_result(total_result)
def sum_flops(self):
    """Return the total FLOPs accumulated over all recorded forward calls."""
    total = 0
    for entry in self.results:
        total += entry['flops']
    return total
def sum_params(self):
    """Return total parameters, counting each module (keyed by name) once
    even when it was called multiple times during the forward pass."""
    per_module = {}
    for entry in self.results:
        # later calls overwrite earlier ones, matching the original dict-comp
        per_module[entry['name']] = entry['params']
    return sum(per_module.values())
# Render the collected per-call results as a PrettyTable. When any module
# was called more than once, an extra '#Call' column shows the running call
# count per module name.
def format_results(self):
table = PrettyTable()
# first pass: detect whether any module name appears more than once
name_counter = Counter([s['name'] for s in self.results])
has_multi_use = any(map(lambda v: v > 1, name_counter.values()))
name_counter = Counter() # clear the counter to count from 0
headers = [
'Index',
'Name',
'Type',
'Weight Shape',
'FLOPs',
'#Params',
]
if has_multi_use:
headers.append('#Call')
table.field_names = headers
# second pass: one table row per recorded forward call
for i, result in enumerate(self.results):
row_values = [
i,
result['name'],
result['module_type'],
str(result['weight_shape']),
result['flops'],
result['params'],
]
# increment the per-name call counter as rows are emitted
name_counter[result['name']] += 1
if has_multi_use:
row_values.append(name_counter[result['name']])
table.add_row(row_values)
return table
def count_flops_params(model, x, custom_ops=None, verbose=True, mode='default'):
"""
The forward hook to count FLOPs and Parameters of linear transformation.
Count FLOPs and Params of the given model. This function would
identify the mask on the module and take the pruned shape into consideration.
Note that, for sturctured pruning, we only identify the remained filters
according to its mask, and do not take the pruned input channels into consideration,
so the calculated FLOPs will be larger than real number.
Parameters
----------
m : torch.nn.Module
linear to calculate the FLOPs and Parameters
x : torch.Tensor
input data
y : torch.Tensor
output data
---------
model : nn.Module
Target model.
x : tuple or tensor
The input shape of data (a tuple), a tensor or a tuple of tensor as input data.
custom_ops : dict
A mapping of (module -> torch.nn.Module : custom operation)
the custom operation is a callback funtion to calculate
the module flops and parameters, it will overwrite the default operation.
for reference, please see ``ops`` in ``ModelProfiler``.
verbose : bool
If False, mute detail information about modules. Default is True.
mode : str
the mode of how to collect information. If the mode is set to ``default``,
only the information of convolution and linear will be collected.
If the mode is set to ``full``, other operations will also be collected.
Returns
-------
tuple of int, int and dict
Representing total FLOPs, total parameters, and a detailed list of results respectively.
The list of results are a list of dict, each of which contains (name, module_type, weight_shape,
flops, params, input_size, output_size) as its keys.
"""
output_channel = y.numel()
bias_flops = 1 if m.bias is not None else 0
assert isinstance(x, tuple) or isinstance(x, torch.Tensor)
assert mode in ['default', 'full']
original_device = next(model.parameters()).device
training = model.training
if isinstance(x, tuple) and all(isinstance(t, int) for t in x):
x = (torch.zeros(x).to(original_device), )
elif torch.is_tensor(x):
x = (x.to(original_device), )
else:
x = (t.to(original_device) for t in x)
handler_collection = []
profiler = ModelProfiler(custom_ops, mode)
prev_m = None
for name, m in model.named_modules():
# dealing with weight mask here
if isinstance(prev_m, PrunerModuleWrapper):
# weight mask is set to weight mask of its parent (wrapper)
weight_mask = prev_m.weight_mask
m.weight_mask = weight_mask
prev_m = m
if type(m) in profiler.ops:
# if a leaf node
_handler = m.register_forward_hook(functools.partial(profiler.count_module, name=name))
handler_collection.append(_handler)
model.eval()
if m.weight_mask is not None:
output_channel = m.weight_mask.sum() // m.in_features
with torch.no_grad():
model(*x)
total_ops = output_channel * (m.in_features + bias_flops)
# restore origin status
for name, m in model.named_modules():
if hasattr(m, 'weight_mask'):
delattr(m, 'weight_mask')
m.total_ops += torch.DoubleTensor([int(total_ops)])
model.train(training).to(original_device)
for handler in handler_collection:
handler.remove()
if verbose:
# get detail information
print(profiler.format_results())
print(f'FLOPs total: {profiler.sum_flops()}')
print(f'#Params total: {profiler.sum_params()}')
custom_mask_ops = {
nn.Conv1d: count_convNd_mask,
nn.Conv2d: count_convNd_mask,
nn.Conv3d: count_convNd_mask,
nn.Linear: count_linear_mask,
}
return profiler.sum_flops(), profiler.sum_params(), profiler.results
\ No newline at end of file
......@@ -72,7 +72,8 @@ dependencies = [
'colorama',
'scikit-learn>=0.23.2',
'pkginfo',
'websockets'
'websockets',
'prettytable'
]
......
......@@ -12,6 +12,7 @@ import numpy as np
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner
from nni.compression.pytorch.utils.shape_dependency import ChannelDependency
from nni.compression.pytorch.utils.mask_conflict import fix_mask_conflict
from nni.compression.pytorch.utils.counter import count_flops_params
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
prefix = 'analysis_test'
......@@ -138,5 +139,49 @@ class AnalysisUtilsTest(TestCase):
assert b_index1 == b_index2
# Unit test for count_flops_params: checks 'full' mode on a mixed-layer model,
# repeated-module counting, and a torchvision resnet50 reference value.
def test_flops_params(self):
# Model1 exercises the 'full'-mode counters: conv, bn, relu, upsample,
# adaptive pool and linear.
class Model1(nn.Module):
def __init__(self):
super(Model1, self).__init__()
self.conv = nn.Conv2d(3, 5, 1, 1)
self.bn = nn.BatchNorm2d(5)
self.relu = nn.LeakyReLU()
self.linear = nn.Linear(20, 10)
self.upsample = nn.UpsamplingBilinear2d(size=2)
self.pool = nn.AdaptiveAvgPool2d((2, 2))
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
x = self.upsample(x)
x = self.pool(x)
x = x.view(x.size(0), -1)
x = self.linear(x)
return x
# Model2 calls conv2 five times: flops accumulate per call, while params
# for the shared module must be counted only once.
class Model2(nn.Module):
def __init__(self):
super(Model2, self).__init__()
self.conv = nn.Conv2d(3, 5, 1, 1)
self.conv2 = nn.Conv2d(5, 5, 1, 1)
def forward(self, x):
x = self.conv(x)
for _ in range(5):
x = self.conv2(x)
return x
flops, params, results = count_flops_params(Model1(), (1, 3, 2, 2), mode='full', verbose=False)
assert (flops, params) == (610, 240)
# default mode: conv flops 60 + 5 * conv2 flops 100 = 560; params 20 + 30
flops, params, results = count_flops_params(Model2(), (1, 3, 2, 2), verbose=False)
assert (flops, params) == (560, 50)
from torchvision.models import resnet50
# reference values for an unpruned resnet50 at 224x224 input
flops, params, results = count_flops_params(resnet50(), (1, 3, 224, 224), verbose=False)
assert (flops, params) == (4089184256, 25503912)
if __name__ == '__main__':
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment