Unverified Commit b6233e52 authored by colorjam's avatar colorjam Committed by GitHub
Browse files

Refactor flops counter (#3048)

parent 291bbbba
...@@ -123,12 +123,26 @@ fixed_mask = fix_mask_conflict('./resnet18_mask', net, data) ...@@ -123,12 +123,26 @@ fixed_mask = fix_mask_conflict('./resnet18_mask', net, data)
## Model FLOPs/Parameters Counter ## Model FLOPs/Parameters Counter
We provide a model counter for calculating the model FLOPs and parameters. This counter supports calculating FLOPs/parameters of a normal model without masks; it can also calculate FLOPs/parameters of a model with mask wrappers, which helps users easily check model complexity during model compression on NNI. Note that, for structured pruning, we only identify the remained filters according to their masks and do not take the pruned input channels into consideration, so the calculated FLOPs will be larger than the real number (i.e., the number calculated after Model Speedup).
We support two modes to collect information of modules. The first mode is `default`, which only collects the information of convolution and linear layers. The second mode is `full`, which also collects the information of other operations. Users can easily use the collected `results` for further analysis.
### Usage ### Usage
``` ```
from nni.compression.pytorch.utils.counter import count_flops_params from nni.compression.pytorch.utils.counter import count_flops_params
# Given input size (1, 1, 28, 28) # Given input size (1, 1, 28, 28)
flops, params = count_flops_params(model, (1, 1, 28, 28)) flops, params, results = count_flops_params(model, (1, 1, 28, 28))
# Given input tensor with size (1, 1, 28, 28) and switch to full mode
x = torch.randn(1, 1, 28, 28)
flops, params, results = count_flops_params(model, (x,), mode='full') # tuple of tensor as input
# Format output size to M (i.e., 10^6) # Format output size to M (i.e., 10^6)
print(f'FLOPs: {flops/1e6:.3f}M, Params: {params/1e6:.3f}M')
print(results)
{
'conv': {'flops': [60], 'params': [20], 'weight_size': [(5, 3, 1, 1)], 'input_size': [(1, 3, 2, 2)], 'output_size': [(1, 5, 2, 2)], 'module_type': ['Conv2d']},
'conv2': {'flops': [100], 'params': [30], 'weight_size': [(5, 5, 1, 1)], 'input_size': [(1, 5, 2, 2)], 'output_size': [(1, 5, 2, 2)], 'module_type': ['Conv2d']}
}
``` ```
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
import functools
from collections import Counter
from prettytable import PrettyTable
import torch import torch
import torch.nn as nn import torch.nn as nn
from nni.compression.pytorch.compressor import PrunerModuleWrapper from nni.compression.pytorch.compressor import PrunerModuleWrapper
try:
from thop import profile __all__ = ['count_flops_params']
except Exception as e:
print('thop is not found, please install the python package: thop')
raise def _get_params(m):
return sum([p.numel() for p in m.parameters()])
def count_flops_params(model: nn.Module, input_size, custom_ops=None, verbose=True): class ModelProfiler:
def __init__(self, custom_ops=None, mode='default'):
""" """
Count FLOPs and Params of the given model. ModelProfiler is used to share state to hooks.
This function would identify the mask on the module
and take the pruned shape into consideration.
Note that, for sturctured pruning, we only identify
the remained filters according to its mask, which
not taking the pruned input channels into consideration,
so the calculated FLOPs will be larger than real number.
Parameters Parameters
--------- ----------
model : nn.Module
target model.
input_size: list, tuple
the input shape of data
custom_ops: dict custom_ops: dict
a mapping of (module: custom operation) a mapping of (module -> torch.nn.Module : custom operation)
the custom operation will overwrite the default operation. the custom operation is a callback funtion to calculate
for reference, please see ``custom_mask_ops``. the module flops, parameters and the weight shape, it will overwrite the default operation.
for reference, please see ``self.ops``.
Returns mode:
------- the mode of how to collect information. If the mode is set to `default`,
flops: float only the information of convolution and linear will be collected.
total flops of the model If the mode is set to `full`, other operations will also be collected.
params:
total params of the model
""" """
self.ops = {
nn.Conv1d: self._count_convNd,
nn.Conv2d: self._count_convNd,
nn.Conv3d: self._count_convNd,
nn.Linear: self._count_linear
}
self._count_bias = False
if mode == 'full':
self.ops.update({
nn.ConvTranspose1d: self._count_convNd,
nn.ConvTranspose2d: self._count_convNd,
nn.ConvTranspose3d: self._count_convNd,
nn.BatchNorm1d: self._count_bn,
nn.BatchNorm2d: self._count_bn,
nn.BatchNorm3d: self._count_bn,
nn.LeakyReLU: self._count_relu,
nn.AvgPool1d: self._count_avgpool,
nn.AvgPool2d: self._count_avgpool,
nn.AvgPool3d: self._count_avgpool,
nn.AdaptiveAvgPool1d: self._count_adap_avgpool,
nn.AdaptiveAvgPool2d: self._count_adap_avgpool,
nn.AdaptiveAvgPool3d: self._count_adap_avgpool,
nn.Upsample: self._count_upsample,
nn.UpsamplingBilinear2d: self._count_upsample,
nn.UpsamplingNearest2d: self._count_upsample
})
self._count_bias = True
assert input_size is not None if custom_ops is not None:
self.ops.update(custom_ops)
device = next(model.parameters()).device self.mode = mode
inputs = torch.randn(input_size).to(device) self.results = []
hook_module_list = [] def _push_result(self, result):
if custom_ops is None: self.results.append(result)
custom_ops = {}
custom_mask_ops.update(custom_ops)
prev_m = None
for m in model.modules():
weight_mask = None
m_type = type(m)
if m_type in custom_mask_ops:
if isinstance(prev_m, PrunerModuleWrapper):
weight_mask = prev_m.weight_mask
m.register_buffer('weight_mask', weight_mask) def _get_result(self, m, flops):
hook_module_list.append(m) # assume weight is called `weight`, otherwise it's not applicable
prev_m = m # if user customize the operation, the callback function should
# return the dict result, inluding calculated flops, params and weight_shape.
flops, params = profile(model, inputs=(inputs, ), custom_ops=custom_mask_ops, verbose=verbose) result = {
'flops': flops,
'params': _get_params(m),
'weight_shape': tuple(m.weight.size()) if hasattr(m, 'weight') else 0,
}
return result
def _count_convNd(self, m, x, y):
cin = m.in_channels
kernel_ops = m.weight.size()[2] * m.weight.size()[3]
output_size = torch.zeros(y.size()[2:]).numel()
cout = y.size()[1]
for m in hook_module_list: if hasattr(m, 'weight_mask'):
m._buffers.pop("weight_mask") cout = m.weight_mask.sum() // (cin * kernel_ops)
# Remove registerd buffer on the model, and fixed following issue:
# https://github.com/Lyken17/pytorch-OpCounter/issues/96
for m in model.modules():
if 'total_ops' in m._buffers:
m._buffers.pop("total_ops")
if 'total_params' in m._buffers:
m._buffers.pop("total_params")
return flops, params total_ops = cout * output_size * kernel_ops * cin // m.groups # cout x oW x oH
def count_convNd_mask(m, x, y): if self._count_bias:
""" bias_flops = 1 if m.bias is not None else 0
The forward hook to count FLOPs and Parameters of convolution operation. total_ops += cout * output_size * bias_flops
Parameters
---------- return self._get_result(m, total_ops)
m : torch.nn.Module
convolution module to calculate the FLOPs and Parameters def _count_linear(self, m, x, y):
x : torch.Tensor out_features = m.out_features
input data if hasattr(m, 'weight_mask'):
y : torch.Tensor out_features = m.weight_mask.sum() // m.in_features
output data total_ops = out_features * m.in_features
"""
output_channel = y.size()[1]
output_size = torch.zeros(y.size()[2:]).numel()
kernel_size = torch.zeros(m.weight.size()[2:]).numel()
if self._count_bias:
bias_flops = 1 if m.bias is not None else 0 bias_flops = 1 if m.bias is not None else 0
total_ops += out_features * bias_flops
return self._get_result(m, total_ops)
def _count_bn(self, m, x, y):
total_ops = 2 * x[0].numel()
return self._get_result(m, total_ops)
def _count_relu(self, m, x, y):
total_ops = x[0].numel()
return self._get_result(m, total_ops)
if m.weight_mask is not None: def _count_avgpool(self, m, x, y):
output_channel = m.weight_mask.sum() // (m.in_channels * kernel_size) total_ops = y.numel()
return self._get_result(m, total_ops)
total_ops = output_channel * output_size * (m.in_channels // m.groups * kernel_size + bias_flops) def _count_adap_avgpool(self, m, x, y):
kernel = torch.Tensor([*(x[0].shape[2:])]) // torch.Tensor(list((m.output_size,))).squeeze()
total_add = int(torch.prod(kernel))
total_div = 1
kernel_ops = total_add + total_div
num_elements = y.numel()
total_ops = kernel_ops * num_elements
m.total_ops += torch.DoubleTensor([int(total_ops)]) return self._get_result(m, total_ops)
def _count_upsample(self, m, x, y):
if m.mode == 'linear':
total_ops = y.nelement() * 5 # 2 muls + 3 add
elif m.mode == 'bilinear':
# https://en.wikipedia.org/wiki/Bilinear_interpolation
total_ops = y.nelement() * 11 # 6 muls + 5 adds
elif m.mode == 'bicubic':
# https://en.wikipedia.org/wiki/Bicubic_interpolation
# Product matrix [4x4] x [4x4] x [4x4]
ops_solve_A = 224 # 128 muls + 96 adds
ops_solve_p = 35 # 16 muls + 12 adds + 4 muls + 3 adds
total_ops = y.nelement() * (ops_solve_A + ops_solve_p)
elif m.mode == 'trilinear':
# https://en.wikipedia.org/wiki/Trilinear_interpolation
# can viewed as 2 bilinear + 1 linear
total_ops = y.nelement() * (13 * 2 + 5)
else:
total_ops = 0
def count_linear_mask(m, x, y): return self._get_result(m, total_ops)
def count_module(self, m, x, y, name):
# assume x is tuple of single tensor
result = self.ops[type(m)](m, x, y)
total_result = {
'name': name,
'input_size': tuple(x[0].size()),
'output_size': tuple(y.size()),
'module_type': type(m).__name__,
**result
}
self._push_result(total_result)
def sum_flops(self):
return sum([s['flops'] for s in self.results])
def sum_params(self):
return sum({s['name']: s['params'] for s in self.results}.values())
def format_results(self):
table = PrettyTable()
name_counter = Counter([s['name'] for s in self.results])
has_multi_use = any(map(lambda v: v > 1, name_counter.values()))
name_counter = Counter() # clear the counter to count from 0
headers = [
'Index',
'Name',
'Type',
'Weight Shape',
'FLOPs',
'#Params',
]
if has_multi_use:
headers.append('#Call')
table.field_names = headers
for i, result in enumerate(self.results):
row_values = [
i,
result['name'],
result['module_type'],
str(result['weight_shape']),
result['flops'],
result['params'],
]
name_counter[result['name']] += 1
if has_multi_use:
row_values.append(name_counter[result['name']])
table.add_row(row_values)
return table
def count_flops_params(model, x, custom_ops=None, verbose=True, mode='default'):
    """
    Count FLOPs and Params of the given model. This function would
    identify the mask on the module and take the pruned shape into consideration.
    Note that, for structured pruning, we only identify the remained filters
    according to its mask, and do not take the pruned input channels into consideration,
    so the calculated FLOPs will be larger than real number.

    Parameters
    ---------
    model : nn.Module
        Target model.
    x : tuple or tensor
        The input shape of data (a tuple), a tensor or a tuple of tensor as input data.
    custom_ops : dict
        A mapping of (module -> torch.nn.Module : custom operation)
        the custom operation is a callback function to calculate
        the module flops and parameters, it will overwrite the default operation.
        for reference, please see ``ops`` in ``ModelProfiler``.
    verbose : bool
        If False, mute detail information about modules. Default is True.
    mode : str
        the mode of how to collect information. If the mode is set to ``default``,
        only the information of convolution and linear will be collected.
        If the mode is set to ``full``, other operations will also be collected.

    Returns
    -------
    tuple of int, int and dict
        Representing total FLOPs, total parameters, and a detailed list of results respectively.
        The list of results are a list of dict, each of which contains (name, module_type, weight_shape,
        flops, params, input_size, output_size) as its keys.
    """

    assert isinstance(x, tuple) or isinstance(x, torch.Tensor)
    assert mode in ['default', 'full']

    original_device = next(model.parameters()).device
    training = model.training

    if isinstance(x, tuple) and all(isinstance(t, int) for t in x):
        # `x` is an input *shape*: synthesize a zero tensor of that shape
        x = (torch.zeros(x).to(original_device), )
    elif torch.is_tensor(x):
        x = (x.to(original_device), )
    else:
        # FIX: materialize as a tuple — the previous generator expression was
        # single-use and not the "tuple of tensor" the docstring promises
        x = tuple(t.to(original_device) for t in x)

    handler_collection = []
    profiler = ModelProfiler(custom_ops, mode)

    prev_m = None
    for name, m in model.named_modules():
        # dealing with weight mask here
        if isinstance(prev_m, PrunerModuleWrapper):
            # weight mask is set to weight mask of its parent (wrapper)
            weight_mask = prev_m.weight_mask
            m.weight_mask = weight_mask
        prev_m = m

        if type(m) in profiler.ops:
            # only modules with a registered counter (leaf ops) get a hook
            _handler = m.register_forward_hook(functools.partial(profiler.count_module, name=name))
            handler_collection.append(_handler)

    model.eval()

    with torch.no_grad():
        model(*x)

    # restore origin status: drop the temporary weight masks and hooks
    for _, m in model.named_modules():
        if hasattr(m, 'weight_mask'):
            delattr(m, 'weight_mask')
    model.train(training).to(original_device)
    for handler in handler_collection:
        handler.remove()

    if verbose:
        # get detail information
        print(profiler.format_results())
        print(f'FLOPs total: {profiler.sum_flops()}')
        print(f'#Params total: {profiler.sum_params()}')

    return profiler.sum_flops(), profiler.sum_params(), profiler.results
...@@ -72,7 +72,8 @@ dependencies = [ ...@@ -72,7 +72,8 @@ dependencies = [
'colorama', 'colorama',
'scikit-learn>=0.23.2', 'scikit-learn>=0.23.2',
'pkginfo', 'pkginfo',
'websockets' 'websockets',
'prettytable'
] ]
......
...@@ -12,6 +12,7 @@ import numpy as np ...@@ -12,6 +12,7 @@ import numpy as np
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner from nni.algorithms.compression.pytorch.pruning import L1FilterPruner
from nni.compression.pytorch.utils.shape_dependency import ChannelDependency from nni.compression.pytorch.utils.shape_dependency import ChannelDependency
from nni.compression.pytorch.utils.mask_conflict import fix_mask_conflict from nni.compression.pytorch.utils.mask_conflict import fix_mask_conflict
from nni.compression.pytorch.utils.counter import count_flops_params
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
prefix = 'analysis_test' prefix = 'analysis_test'
...@@ -138,5 +139,49 @@ class AnalysisUtilsTest(TestCase): ...@@ -138,5 +139,49 @@ class AnalysisUtilsTest(TestCase):
assert b_index1 == b_index2 assert b_index1 == b_index2
    def test_flops_params(self):
        # Sanity-check count_flops_params on two small hand-built models and
        # on torchvision's resnet50, pinning exact (flops, params) totals.

        class Model1(nn.Module):
            # Exercises `full` mode: conv + bn + relu + upsample + pool + linear.
            def __init__(self):
                super(Model1, self).__init__()
                self.conv = nn.Conv2d(3, 5, 1, 1)
                self.bn = nn.BatchNorm2d(5)
                self.relu = nn.LeakyReLU()
                self.linear = nn.Linear(20, 10)
                self.upsample = nn.UpsamplingBilinear2d(size=2)
                self.pool = nn.AdaptiveAvgPool2d((2, 2))

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                x = self.relu(x)
                x = self.upsample(x)
                x = self.pool(x)
                x = x.view(x.size(0), -1)
                x = self.linear(x)
                return x

        class Model2(nn.Module):
            # conv2 is called 5 times in forward -> checks multi-call counting
            # (flops accumulate per call, params counted once per module name).
            def __init__(self):
                super(Model2, self).__init__()
                self.conv = nn.Conv2d(3, 5, 1, 1)
                self.conv2 = nn.Conv2d(5, 5, 1, 1)

            def forward(self, x):
                x = self.conv(x)
                for _ in range(5):
                    x = self.conv2(x)
                return x

        # `full` mode counts bn/relu/upsample/pool in addition to conv/linear
        flops, params, results = count_flops_params(Model1(), (1, 3, 2, 2), mode='full', verbose=False)
        assert (flops, params) == (610, 240)

        # `default` mode: only the two convolutions contribute
        flops, params, results = count_flops_params(Model2(), (1, 3, 2, 2), verbose=False)
        assert (flops, params) == (560, 50)

        from torchvision.models import resnet50
        # regression check against a real-world network
        flops, params, results = count_flops_params(resnet50(), (1, 3, 224, 224), verbose=False)
        assert (flops, params) == (4089184256, 25503912)
if __name__ == '__main__': if __name__ == '__main__':
main() main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment