Unverified Commit b6233e52 authored by colorjam's avatar colorjam Committed by GitHub
Browse files

Refactor flops counter (#3048)

parent 291bbbba
...@@ -121,14 +121,28 @@ fixed_mask = fix_mask_conflict('./resnet18_mask', net, data) ...@@ -121,14 +121,28 @@ fixed_mask = fix_mask_conflict('./resnet18_mask', net, data)
``` ```
## Model FLOPs/Parameters Counter ## Model FLOPs/Parameters Counter
We provide a model counter for calculating the model FLOPs and parameters. This counter supports calculating FLOPs/parameters of a normal model without masks; it can also calculate FLOPs/parameters of a model with mask wrappers, which helps users easily check model complexity during model compression on NNI. Note that, for structured pruning, we only identify the remaining filters according to their masks and do not take the pruned input channels into consideration, so the calculated FLOPs will be larger than the real number (i.e., the number calculated after Model Speedup). We provide a model counter for calculating the model FLOPs and parameters. This counter supports calculating FLOPs/parameters of a normal model without masks; it can also calculate FLOPs/parameters of a model with mask wrappers, which helps users easily check model complexity during model compression on NNI. Note that, for structured pruning, we only identify the remaining filters according to their masks and do not take the pruned input channels into consideration, so the calculated FLOPs will be larger than the real number (i.e., the number calculated after Model Speedup).
We support two modes to collect information of modules. The first mode is `default`, which only collects the information of convolution and linear layers. The second mode is `full`, which also collects the information of other operations. Users can easily use the collected `results` for further analysis.
### Usage ### Usage
``` ```
from nni.compression.pytorch.utils.counter import count_flops_params from nni.compression.pytorch.utils.counter import count_flops_params
# Given input size (1, 1, 28, 28) # Given input size (1, 1, 28, 28)
flops, params = count_flops_params(model, (1, 1, 28, 28)) flops, params, results = count_flops_params(model, (1, 1, 28, 28))
# Given input tensor with size (1, 1, 28, 28) and switch to full mode
x = torch.randn(1, 1, 28, 28)
flops, params, results = count_flops_params(model, (x,), mode='full') # tuple of tensor as input
# Format output size to M (i.e., 10^6) # Format output size to M (i.e., 10^6)
print(f'FLOPs: {flops/1e6:.3f}M, Params: {params/1e6:.3f}M') print(f'FLOPs: {flops/1e6:.3f}M, Params: {params/1e6:.3f}M')
print(results)
{
'conv': {'flops': [60], 'params': [20], 'weight_size': [(5, 3, 1, 1)], 'input_size': [(1, 3, 2, 2)], 'output_size': [(1, 5, 2, 2)], 'module_type': ['Conv2d']},
'conv2': {'flops': [100], 'params': [30], 'weight_size': [(5, 5, 1, 1)], 'input_size': [(1, 5, 2, 2)], 'output_size': [(1, 5, 2, 2)], 'module_type': ['Conv2d']}
}
``` ```
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
import functools
from collections import Counter
from prettytable import PrettyTable
import torch import torch
import torch.nn as nn import torch.nn as nn
from nni.compression.pytorch.compressor import PrunerModuleWrapper from nni.compression.pytorch.compressor import PrunerModuleWrapper
try:
from thop import profile
except Exception as e:
print('thop is not found, please install the python package: thop')
raise
__all__ = ['count_flops_params']
def count_flops_params(model: nn.Module, input_size, custom_ops=None, verbose=True):
"""
Count FLOPs and Params of the given model.
This function would identify the mask on the module
and take the pruned shape into consideration.
Note that, for sturctured pruning, we only identify
the remained filters according to its mask, which
not taking the pruned input channels into consideration,
so the calculated FLOPs will be larger than real number.
Parameters def _get_params(m):
--------- return sum([p.numel() for p in m.parameters()])
model : nn.Module
target model.
input_size: list, tuple
the input shape of data
custom_ops: dict
a mapping of (module: custom operation)
the custom operation will overwrite the default operation.
for reference, please see ``custom_mask_ops``.
Returns
-------
flops: float
total flops of the model
params:
total params of the model
"""
assert input_size is not None class ModelProfiler:
device = next(model.parameters()).device def __init__(self, custom_ops=None, mode='default'):
inputs = torch.randn(input_size).to(device) """
ModelProfiler is used to share state to hooks.
hook_module_list = [] Parameters
if custom_ops is None: ----------
custom_ops = {} custom_ops: dict
custom_mask_ops.update(custom_ops) a mapping of (module -> torch.nn.Module : custom operation)
prev_m = None the custom operation is a callback funtion to calculate
for m in model.modules(): the module flops, parameters and the weight shape, it will overwrite the default operation.
weight_mask = None for reference, please see ``self.ops``.
m_type = type(m) mode:
if m_type in custom_mask_ops: the mode of how to collect information. If the mode is set to `default`,
if isinstance(prev_m, PrunerModuleWrapper): only the information of convolution and linear will be collected.
weight_mask = prev_m.weight_mask If the mode is set to `full`, other operations will also be collected.
"""
m.register_buffer('weight_mask', weight_mask) self.ops = {
hook_module_list.append(m) nn.Conv1d: self._count_convNd,
prev_m = m nn.Conv2d: self._count_convNd,
nn.Conv3d: self._count_convNd,
nn.Linear: self._count_linear
}
self._count_bias = False
if mode == 'full':
self.ops.update({
nn.ConvTranspose1d: self._count_convNd,
nn.ConvTranspose2d: self._count_convNd,
nn.ConvTranspose3d: self._count_convNd,
nn.BatchNorm1d: self._count_bn,
nn.BatchNorm2d: self._count_bn,
nn.BatchNorm3d: self._count_bn,
nn.LeakyReLU: self._count_relu,
nn.AvgPool1d: self._count_avgpool,
nn.AvgPool2d: self._count_avgpool,
nn.AvgPool3d: self._count_avgpool,
nn.AdaptiveAvgPool1d: self._count_adap_avgpool,
nn.AdaptiveAvgPool2d: self._count_adap_avgpool,
nn.AdaptiveAvgPool3d: self._count_adap_avgpool,
nn.Upsample: self._count_upsample,
nn.UpsamplingBilinear2d: self._count_upsample,
nn.UpsamplingNearest2d: self._count_upsample
})
self._count_bias = True
flops, params = profile(model, inputs=(inputs, ), custom_ops=custom_mask_ops, verbose=verbose) if custom_ops is not None:
self.ops.update(custom_ops)
self.mode = mode
self.results = []
for m in hook_module_list: def _push_result(self, result):
m._buffers.pop("weight_mask") self.results.append(result)
# Remove registerd buffer on the model, and fixed following issue:
# https://github.com/Lyken17/pytorch-OpCounter/issues/96
for m in model.modules():
if 'total_ops' in m._buffers:
m._buffers.pop("total_ops")
if 'total_params' in m._buffers:
m._buffers.pop("total_params")
return flops, params def _get_result(self, m, flops):
# assume weight is called `weight`, otherwise it's not applicable
# if user customize the operation, the callback function should
# return the dict result, inluding calculated flops, params and weight_shape.
def count_convNd_mask(m, x, y): result = {
""" 'flops': flops,
The forward hook to count FLOPs and Parameters of convolution operation. 'params': _get_params(m),
Parameters 'weight_shape': tuple(m.weight.size()) if hasattr(m, 'weight') else 0,
---------- }
m : torch.nn.Module return result
convolution module to calculate the FLOPs and Parameters
x : torch.Tensor def _count_convNd(self, m, x, y):
input data cin = m.in_channels
y : torch.Tensor kernel_ops = m.weight.size()[2] * m.weight.size()[3]
output data output_size = torch.zeros(y.size()[2:]).numel()
""" cout = y.size()[1]
output_channel = y.size()[1]
output_size = torch.zeros(y.size()[2:]).numel() if hasattr(m, 'weight_mask'):
kernel_size = torch.zeros(m.weight.size()[2:]).numel() cout = m.weight_mask.sum() // (cin * kernel_ops)
total_ops = cout * output_size * kernel_ops * cin // m.groups # cout x oW x oH
if self._count_bias:
bias_flops = 1 if m.bias is not None else 0
total_ops += cout * output_size * bias_flops
return self._get_result(m, total_ops)
def _count_linear(self, m, x, y):
out_features = m.out_features
if hasattr(m, 'weight_mask'):
out_features = m.weight_mask.sum() // m.in_features
total_ops = out_features * m.in_features
if self._count_bias:
bias_flops = 1 if m.bias is not None else 0
total_ops += out_features * bias_flops
return self._get_result(m, total_ops)
def _count_bn(self, m, x, y):
total_ops = 2 * x[0].numel()
return self._get_result(m, total_ops)
def _count_relu(self, m, x, y):
total_ops = x[0].numel()
return self._get_result(m, total_ops)
bias_flops = 1 if m.bias is not None else 0 def _count_avgpool(self, m, x, y):
total_ops = y.numel()
return self._get_result(m, total_ops)
if m.weight_mask is not None: def _count_adap_avgpool(self, m, x, y):
output_channel = m.weight_mask.sum() // (m.in_channels * kernel_size) kernel = torch.Tensor([*(x[0].shape[2:])]) // torch.Tensor(list((m.output_size,))).squeeze()
total_add = int(torch.prod(kernel))
total_div = 1
kernel_ops = total_add + total_div
num_elements = y.numel()
total_ops = kernel_ops * num_elements
total_ops = output_channel * output_size * (m.in_channels // m.groups * kernel_size + bias_flops) return self._get_result(m, total_ops)
m.total_ops += torch.DoubleTensor([int(total_ops)]) def _count_upsample(self, m, x, y):
if m.mode == 'linear':
total_ops = y.nelement() * 5 # 2 muls + 3 add
elif m.mode == 'bilinear':
# https://en.wikipedia.org/wiki/Bilinear_interpolation
total_ops = y.nelement() * 11 # 6 muls + 5 adds
elif m.mode == 'bicubic':
# https://en.wikipedia.org/wiki/Bicubic_interpolation
# Product matrix [4x4] x [4x4] x [4x4]
ops_solve_A = 224 # 128 muls + 96 adds
ops_solve_p = 35 # 16 muls + 12 adds + 4 muls + 3 adds
total_ops = y.nelement() * (ops_solve_A + ops_solve_p)
elif m.mode == 'trilinear':
# https://en.wikipedia.org/wiki/Trilinear_interpolation
# can viewed as 2 bilinear + 1 linear
total_ops = y.nelement() * (13 * 2 + 5)
else:
total_ops = 0
return self._get_result(m, total_ops)
def count_linear_mask(m, x, y): def count_module(self, m, x, y, name):
# assume x is tuple of single tensor
result = self.ops[type(m)](m, x, y)
total_result = {
'name': name,
'input_size': tuple(x[0].size()),
'output_size': tuple(y.size()),
'module_type': type(m).__name__,
**result
}
self._push_result(total_result)
def sum_flops(self):
    """
    Return the sum of the ``'flops'`` entry over every collected result
    (one entry per recorded module call). Returns 0 when nothing was collected.
    """
    total = 0
    for record in self.results:
        # plain (non in-place) addition, matching what ``sum`` does
        total = total + record['flops']
    return total
def sum_params(self):
    """
    Return the sum of the ``'params'`` entries, counting each module name once.

    Results sharing the same ``'name'`` collapse to a single entry (the last
    one recorded wins), so a module called several times is not counted twice.
    """
    params_by_name = {}
    for record in self.results:
        params_by_name[record['name']] = record['params']
    return sum(params_by_name.values())
def format_results(self):
    """
    Build a ``PrettyTable`` with one row per collected result.

    Columns are Index, Name, Type, Weight Shape, FLOPs and #Params. A
    ``#Call`` column (running call number of the module, starting at 1) is
    appended only when at least one module name appears more than once.

    Returns
    -------
    PrettyTable
        The populated table.
    """
    names = [record['name'] for record in self.results]
    show_calls = any(count > 1 for count in Counter(names).values())

    columns = ['Index', 'Name', 'Type', 'Weight Shape', 'FLOPs', '#Params']
    if show_calls:
        columns.append('#Call')

    table = PrettyTable()
    table.field_names = columns

    # running number of calls seen so far for each module name
    calls_so_far = Counter()
    for index, record in enumerate(self.results):
        name = record['name']
        row = [
            index,
            name,
            record['module_type'],
            str(record['weight_shape']),
            record['flops'],
            record['params'],
        ]
        calls_so_far[name] += 1
        if show_calls:
            row.append(calls_so_far[name])
        table.add_row(row)
    return table
def count_flops_params(model, x, custom_ops=None, verbose=True, mode='default'):
""" """
The forward hook to count FLOPs and Parameters of linear transformation. Count FLOPs and Params of the given model. This function would
identify the mask on the module and take the pruned shape into consideration.
Note that, for sturctured pruning, we only identify the remained filters
according to its mask, and do not take the pruned input channels into consideration,
so the calculated FLOPs will be larger than real number.
Parameters Parameters
---------- ---------
m : torch.nn.Module model : nn.Module
linear to calculate the FLOPs and Parameters Target model.
x : torch.Tensor x : tuple or tensor
input data The input shape of data (a tuple), a tensor or a tuple of tensor as input data.
y : torch.Tensor custom_ops : dict
output data A mapping of (module -> torch.nn.Module : custom operation)
the custom operation is a callback funtion to calculate
the module flops and parameters, it will overwrite the default operation.
for reference, please see ``ops`` in ``ModelProfiler``.
verbose : bool
If False, mute detail information about modules. Default is True.
mode : str
the mode of how to collect information. If the mode is set to ``default``,
only the information of convolution and linear will be collected.
If the mode is set to ``full``, other operations will also be collected.
Returns
-------
tuple of int, int and dict
Representing total FLOPs, total parameters, and a detailed list of results respectively.
The list of results are a list of dict, each of which contains (name, module_type, weight_shape,
flops, params, input_size, output_size) as its keys.
""" """
output_channel = y.numel()
bias_flops = 1 if m.bias is not None else 0 assert isinstance(x, tuple) or isinstance(x, torch.Tensor)
assert mode in ['default', 'full']
original_device = next(model.parameters()).device
training = model.training
if isinstance(x, tuple) and all(isinstance(t, int) for t in x):
x = (torch.zeros(x).to(original_device), )
elif torch.is_tensor(x):
x = (x.to(original_device), )
else:
x = (t.to(original_device) for t in x)
handler_collection = []
profiler = ModelProfiler(custom_ops, mode)
prev_m = None
for name, m in model.named_modules():
# dealing with weight mask here
if isinstance(prev_m, PrunerModuleWrapper):
# weight mask is set to weight mask of its parent (wrapper)
weight_mask = prev_m.weight_mask
m.weight_mask = weight_mask
prev_m = m
if type(m) in profiler.ops:
# if a leaf node
_handler = m.register_forward_hook(functools.partial(profiler.count_module, name=name))
handler_collection.append(_handler)
model.eval()
if m.weight_mask is not None: with torch.no_grad():
output_channel = m.weight_mask.sum() // m.in_features model(*x)
total_ops = output_channel * (m.in_features + bias_flops) # restore origin status
for name, m in model.named_modules():
if hasattr(m, 'weight_mask'):
delattr(m, 'weight_mask')
m.total_ops += torch.DoubleTensor([int(total_ops)]) model.train(training).to(original_device)
for handler in handler_collection:
handler.remove()
if verbose:
# get detail information
print(profiler.format_results())
print(f'FLOPs total: {profiler.sum_flops()}')
print(f'#Params total: {profiler.sum_params()}')
custom_mask_ops = { return profiler.sum_flops(), profiler.sum_params(), profiler.results
nn.Conv1d: count_convNd_mask, \ No newline at end of file
nn.Conv2d: count_convNd_mask,
nn.Conv3d: count_convNd_mask,
nn.Linear: count_linear_mask,
}
...@@ -72,7 +72,8 @@ dependencies = [ ...@@ -72,7 +72,8 @@ dependencies = [
'colorama', 'colorama',
'scikit-learn>=0.23.2', 'scikit-learn>=0.23.2',
'pkginfo', 'pkginfo',
'websockets' 'websockets',
'prettytable'
] ]
......
...@@ -12,6 +12,7 @@ import numpy as np ...@@ -12,6 +12,7 @@ import numpy as np
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner from nni.algorithms.compression.pytorch.pruning import L1FilterPruner
from nni.compression.pytorch.utils.shape_dependency import ChannelDependency from nni.compression.pytorch.utils.shape_dependency import ChannelDependency
from nni.compression.pytorch.utils.mask_conflict import fix_mask_conflict from nni.compression.pytorch.utils.mask_conflict import fix_mask_conflict
from nni.compression.pytorch.utils.counter import count_flops_params
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
prefix = 'analysis_test' prefix = 'analysis_test'
...@@ -138,5 +139,49 @@ class AnalysisUtilsTest(TestCase): ...@@ -138,5 +139,49 @@ class AnalysisUtilsTest(TestCase):
assert b_index1 == b_index2 assert b_index1 == b_index2
def test_flops_params(self):
    """
    Check the FLOPs / parameter totals returned by ``count_flops_params`` on
    a small model in ``full`` mode, on a model that reuses one layer several
    times in ``default`` mode, and on torchvision's ResNet-50.
    """
    class Model1(nn.Module):
        # one instance of several layer types that are counted in 'full' mode
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 5, 1, 1)
            self.bn = nn.BatchNorm2d(5)
            self.relu = nn.LeakyReLU()
            self.linear = nn.Linear(20, 10)
            self.upsample = nn.UpsamplingBilinear2d(size=2)
            self.pool = nn.AdaptiveAvgPool2d((2, 2))

        def forward(self, x):
            x = self.relu(self.bn(self.conv(x)))
            x = self.pool(self.upsample(x))
            x = x.view(x.size(0), -1)
            return self.linear(x)

    class Model2(nn.Module):
        # conv2 is applied five times in a single forward pass
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 5, 1, 1)
            self.conv2 = nn.Conv2d(5, 5, 1, 1)

        def forward(self, x):
            x = self.conv(x)
            for _ in range(5):
                x = self.conv2(x)
            return x

    flops, params, _ = count_flops_params(Model1(), (1, 3, 2, 2), mode='full', verbose=False)
    assert (flops, params) == (610, 240)

    flops, params, _ = count_flops_params(Model2(), (1, 3, 2, 2), verbose=False)
    assert (flops, params) == (560, 50)

    from torchvision.models import resnet50
    flops, params, _ = count_flops_params(resnet50(), (1, 3, 224, 224), verbose=False)
    assert (flops, params) == (4089184256, 25503912)
if __name__ == '__main__': if __name__ == '__main__':
main() main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment