"testing/vscode:/vscode.git/clone" did not exist on "93294e61393e349e8ef4caeb3cd3e0e4fad89a10"
Unverified Commit 7fd07766 authored by colorjam, committed by GitHub

Fix bug of FLOPs counter (#3497)

parent 638da0bd
@@ -7,6 +7,7 @@ from prettytable import PrettyTable
 import torch
 import torch.nn as nn
+from torch.nn.utils.rnn import PackedSequence
 from nni.compression.pytorch.compressor import PrunerModuleWrapper
@@ -32,21 +33,27 @@ class ModelProfiler:
         for reference, please see ``self.ops``.
         mode:
             the mode of how to collect information. If the mode is set to `default`,
-            only the information of convolution and linear will be collected.
+            only the information of convolution, linear and rnn modules will be collected.
             If the mode is set to `full`, other operations will also be collected.
         """
         self.ops = {
             nn.Conv1d: self._count_convNd,
             nn.Conv2d: self._count_convNd,
             nn.Conv3d: self._count_convNd,
-            nn.Linear: self._count_linear
+            nn.ConvTranspose1d: self._count_convNd,
+            nn.ConvTranspose2d: self._count_convNd,
+            nn.ConvTranspose3d: self._count_convNd,
+            nn.Linear: self._count_linear,
+            nn.RNNCell: self._count_rnn_cell,
+            nn.GRUCell: self._count_gru_cell,
+            nn.LSTMCell: self._count_lstm_cell,
+            nn.RNN: self._count_rnn,
+            nn.GRU: self._count_gru,
+            nn.LSTM: self._count_lstm
         }
         self._count_bias = False
         if mode == 'full':
             self.ops.update({
-                nn.ConvTranspose1d: self._count_convNd,
-                nn.ConvTranspose2d: self._count_convNd,
-                nn.ConvTranspose3d: self._count_convNd,
                 nn.BatchNorm1d: self._count_bn,
                 nn.BatchNorm2d: self._count_bn,
                 nn.BatchNorm3d: self._count_bn,
@@ -86,7 +93,7 @@ class ModelProfiler:
     def _count_convNd(self, m, x, y):
         cin = m.in_channels
-        kernel_ops = m.weight.size()[2] * m.weight.size()[3]
+        kernel_ops = torch.zeros(m.weight.size()[2:]).numel()
         output_size = torch.zeros(y.size()[2:]).numel()
         cout = y.size()[1]
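Editor's note: the old expression hard-coded a 4-D (Conv2d) weight layout; the new one multiplies all trailing spatial dimensions, so Conv1d/Conv3d and the transposed variants are handled by the same counter. A quick illustrative check (not part of the commit):

    import torch
    import torch.nn as nn

    for m in (nn.Conv1d(3, 8, 5), nn.Conv2d(3, 8, 3), nn.Conv3d(3, 8, 3),
              nn.ConvTranspose2d(3, 8, 3)):
        # numel() over the trailing (spatial) dims: 5, 9, 27, 9 respectively
        kernel_ops = torch.zeros(m.weight.size()[2:]).numel()
        print(type(m).__name__, kernel_ops)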
@@ -156,13 +163,125 @@ class ModelProfiler:
         return self._get_result(m, total_ops)
+    def _count_cell_flops(self, input_size, hidden_size, cell_type):
+        # h' = \tanh(W_{ih} x + b_{ih} + W_{hh} h + b_{hh})
+        total_ops = hidden_size * (input_size + hidden_size) + hidden_size
+        if self._count_bias:
+            total_ops += hidden_size * 2
+
+        if cell_type == 'rnn':
+            return total_ops
+
+        if cell_type == 'gru':
+            # r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr})
+            # z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz})
+            # n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn}))
+            total_ops *= 3
+            # hadamard product with r inside the n gate: r * (W_{hn} h + b_{hn})
+            total_ops += hidden_size
+            # h' = (1 - z) * n + z * h: two hadamard products and one add
+            total_ops += hidden_size * 3
+        elif cell_type == 'lstm':
+            # i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi})
+            # f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf})
+            # o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho})
+            # g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg})
+            total_ops *= 4
+            # c' = f * c + i * g: two hadamard products and one add
+            total_ops += hidden_size * 3
+            # h' = o * \tanh(c')
+            total_ops += hidden_size
+
+        return total_ops
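The per-gate count above is hidden_size * (input_size + hidden_size) + hidden_size: one input-to-hidden and one hidden-to-hidden matrix-vector product plus their sum. An illustrative trace for an LSTM cell with input_size=10, hidden_size=20 and bias not counted:

    input_size, hidden_size = 10, 20
    base = hidden_size * (input_size + hidden_size) + hidden_size  # one gate: 620
    total = base * 4             # four gates i, f, o, g: 2480
    total += hidden_size * 3     # c' = f * c + i * g: 2540
    total += hidden_size         # h' = o * tanh(c'): 2560
    assert total == 2560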
+    def _count_rnn_cell(self, m, x, y):
+        total_ops = self._count_cell_flops(m.input_size, m.hidden_size, 'rnn')
+        batch_size = x[0].size(0)
+        total_ops *= batch_size
+        return self._get_result(m, total_ops)
+
+    def _count_gru_cell(self, m, x, y):
+        total_ops = self._count_cell_flops(m.input_size, m.hidden_size, 'gru')
+        batch_size = x[0].size(0)
+        total_ops *= batch_size
+        return self._get_result(m, total_ops)
+
+    def _count_lstm_cell(self, m, x, y):
+        total_ops = self._count_cell_flops(m.input_size, m.hidden_size, 'lstm')
+        batch_size = x[0].size(0)
+        total_ops *= batch_size
+        return self._get_result(m, total_ops)
+    def _get_bsize_nsteps(self, m, x):
+        if isinstance(x[0], PackedSequence):
+            batch_size = torch.max(x[0].batch_sizes)
+            num_steps = x[0].batch_sizes.size(0)
+        else:
+            if m.batch_first:
+                batch_size = x[0].size(0)
+                num_steps = x[0].size(1)
+            else:
+                batch_size = x[0].size(1)
+                num_steps = x[0].size(0)
+        return batch_size, num_steps
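For a PackedSequence, batch_sizes holds the number of still-active sequences at each time step, so its maximum is the batch size and its length is the step count. Illustrative check (not part of the commit):

    import torch
    from torch.nn.utils.rnn import pack_padded_sequence

    padded = torch.randn(2, 3, 4)  # (batch, time, features)
    packed = pack_padded_sequence(padded, lengths=[3, 2], batch_first=True)
    print(packed.batch_sizes)              # tensor([2, 2, 1])
    print(torch.max(packed.batch_sizes))   # tensor(2) -> batch_size
    print(packed.batch_sizes.size(0))      # 3 -> num_steps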
+    def _count_rnn_module(self, m, x, y, module_name):
+        input_size = m.input_size
+        hidden_size = m.hidden_size
+        num_layers = m.num_layers
+        batch_size, num_steps = self._get_bsize_nsteps(m, x)
+
+        total_ops = self._count_cell_flops(input_size, hidden_size, module_name)
+        for _ in range(num_layers - 1):
+            if m.bidirectional:
+                cell_flops = self._count_cell_flops(hidden_size * 2, hidden_size, module_name) * 2
+            else:
+                cell_flops = self._count_cell_flops(hidden_size, hidden_size, module_name)
+            total_ops += cell_flops
+
+        total_ops *= num_steps
+        total_ops *= batch_size
+        return total_ops
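A trace of the accumulation above for nn.LSTM(input_size=10, hidden_size=20, num_layers=2), unidirectional, on input of shape (seq_len=5, batch=3, 10); the figures reuse the 2560 from the cell example:

    layer1 = 2560                      # _count_cell_flops(10, 20, 'lstm')
    layer2 = 3360                      # _count_cell_flops(20, 20, 'lstm')
    total = (layer1 + layer2) * 5 * 3  # x num_steps, x batch_size
    assert total == 88800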
+    def _count_rnn(self, m, x, y):
+        total_ops = self._count_rnn_module(m, x, y, 'rnn')
+        return self._get_result(m, total_ops)
+
+    def _count_gru(self, m, x, y):
+        total_ops = self._count_rnn_module(m, x, y, 'gru')
+        return self._get_result(m, total_ops)
+
+    def _count_lstm(self, m, x, y):
+        total_ops = self._count_rnn_module(m, x, y, 'lstm')
+        return self._get_result(m, total_ops)
     def count_module(self, m, x, y, name):
         # assume x is tuple of single tensor
         result = self.ops[type(m)](m, x, y)
+        output_size = y[0].size() if isinstance(y, tuple) else y.size()
+
         total_result = {
             'name': name,
             'input_size': tuple(x[0].size()),
-            'output_size': tuple(y.size()),
+            'output_size': tuple(output_size),
             'module_type': type(m).__name__,
             **result
         }
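Recurrent modules return a tuple rather than a single tensor, which is why count_module now unwraps y before reading its size; calling y.size() directly would raise AttributeError for an nn.LSTM. Illustrative check:

    import torch
    import torch.nn as nn

    lstm = nn.LSTM(input_size=10, hidden_size=20)
    y = lstm(torch.randn(5, 3, 10))     # y == (output, (h_n, c_n))
    output_size = y[0].size() if isinstance(y, tuple) else y.size()
    print(tuple(output_size))           # (5, 3, 20)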
@@ -279,10 +398,6 @@ def count_flops_params(model, x, custom_ops=None, verbose=True, mode='default'):
     model(*x)

     # restore origin status
-    for name, m in model.named_modules():
-        if hasattr(m, 'weight_mask'):
-            delattr(m, 'weight_mask')
     model.train(training).to(original_device)
     for handler in handler_collection:
         handler.remove()
...
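A hedged end-to-end sketch of exercising the fixed counter on a recurrent model. The import path follows NNI's layout around this commit and the exact return values may differ by release, so treat both as assumptions:

    import torch
    import torch.nn as nn
    from nni.compression.pytorch.utils.counter import count_flops_params

    class TinyRNN(nn.Module):
        def __init__(self):
            super().__init__()
            self.lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)

        def forward(self, x):
            output, _ = self.lstm(x)
            return output

    # Before this fix, nn.LSTM fell through the ops table and its tuple
    # output crashed count_module; both paths are exercised here.
    results = count_flops_params(TinyRNN(), (torch.randn(5, 3, 10),), mode='default')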