Unverified commit bcc55c52, authored by Yuge Zhang, committed by GitHub

Fix flops counter for `bs>1` (#4154)

parent e98ebcf0
@@ -121,15 +121,15 @@ class ModelProfiler:
         return self._get_result(m, total_ops)
 
     def _count_bn(self, m, x, y):
-        total_ops = 2 * x[0].numel()
+        total_ops = 2 * x[0][0].numel()
         return self._get_result(m, total_ops)
 
     def _count_relu(self, m, x, y):
-        total_ops = x[0].numel()
+        total_ops = x[0][0].numel()
         return self._get_result(m, total_ops)
 
     def _count_avgpool(self, m, x, y):
-        total_ops = y.numel()
+        total_ops = y[0].numel()
        return self._get_result(m, total_ops)
 
     def _count_adap_avgpool(self, m, x, y):
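The counter hooks receive the module inputs as a tuple, so `x[0]` is the full batched input tensor and `x[0][0]` is a single sample; likewise `y[0]` is one sample of the output. The following is a minimal stand-alone sketch, not NNI's actual profiler (the helper names are made up for illustration), showing why indexing out the batch dimension makes the count independent of batch size:

```python
import torch
import torch.nn as nn

# Hypothetical stand-alone illustration of the indexing used in the hunk above.
# x is the tuple of inputs a forward hook receives, y is the module output.
def bn_flops_per_batch(x, y):
    return 2 * x[0].numel()        # old behaviour: grows with batch size

def bn_flops_per_sample(x, y):
    return 2 * x[0][0].numel()     # new behaviour: one sample only

bn = nn.BatchNorm2d(3)
for bs in (1, 4):
    inp = torch.randn(bs, 3, 8, 8)
    out = bn(inp)
    x = (inp,)                     # forward hooks pass inputs as a tuple
    print(bs, bn_flops_per_batch(x, out), bn_flops_per_sample(x, out))
    # bs=1: 384, 384   bs=4: 1536, 384  -> the per-sample count is constant
```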
@@ -137,27 +137,27 @@ class ModelProfiler:
         total_add = int(torch.prod(kernel))
         total_div = 1
         kernel_ops = total_add + total_div
-        num_elements = y.numel()
+        num_elements = y[0].numel()
         total_ops = kernel_ops * num_elements
         return self._get_result(m, total_ops)
 
     def _count_upsample(self, m, x, y):
         if m.mode == 'linear':
-            total_ops = y.nelement() * 5  # 2 muls + 3 adds
+            total_ops = y[0].nelement() * 5  # 2 muls + 3 adds
         elif m.mode == 'bilinear':
             # https://en.wikipedia.org/wiki/Bilinear_interpolation
-            total_ops = y.nelement() * 11  # 6 muls + 5 adds
+            total_ops = y[0].nelement() * 11  # 6 muls + 5 adds
         elif m.mode == 'bicubic':
             # https://en.wikipedia.org/wiki/Bicubic_interpolation
             # Product matrix [4x4] x [4x4] x [4x4]
             ops_solve_A = 224  # 128 muls + 96 adds
             ops_solve_p = 35  # 16 muls + 12 adds + 4 muls + 3 adds
-            total_ops = y.nelement() * (ops_solve_A + ops_solve_p)
+            total_ops = y[0].nelement() * (ops_solve_A + ops_solve_p)
         elif m.mode == 'trilinear':
             # https://en.wikipedia.org/wiki/Trilinear_interpolation
             # can be viewed as 2 bilinear + 1 linear
-            total_ops = y.nelement() * (13 * 2 + 5)
+            total_ops = y[0].nelement() * (13 * 2 + 5)
         else:
             total_ops = 0
@@ -202,26 +202,16 @@ class ModelProfiler:
         return total_ops
 
     def _count_rnn_cell(self, m, x, y):
         total_ops = self._count_cell_flops(m.input_size, m.hidden_size, 'rnn')
-        batch_size = x[0].size(0)
-        total_ops *= batch_size
         return self._get_result(m, total_ops)
 
     def _count_gru_cell(self, m, x, y):
         total_ops = self._count_cell_flops(m.input_size, m.hidden_size, 'gru')
-        batch_size = x[0].size(0)
-        total_ops *= batch_size
         return self._get_result(m, total_ops)
 
     def _count_lstm_cell(self, m, x, y):
         total_ops = self._count_cell_flops(m.input_size, m.hidden_size, 'lstm')
-        batch_size = x[0].size(0)
-        total_ops *= batch_size
         return self._get_result(m, total_ops)
 
     def _get_bsize_nsteps(self, m, x):
@@ -243,18 +233,17 @@ class ModelProfiler:
         hidden_size = m.hidden_size
         num_layers = m.num_layers
-        batch_size, num_steps = self._get_bsize_nsteps(m, x)
+        _, num_steps = self._get_bsize_nsteps(m, x)
         total_ops = self._count_cell_flops(input_size, hidden_size, module_name)
         for _ in range(num_layers - 1):
             if m.bidirectional:
                 cell_flops = self._count_cell_flops(hidden_size * 2, hidden_size, module_name) * 2
             else:
-                cell_flops = self._count_cell_flops(hidden_size, hidden_size,module_name)
+                cell_flops = self._count_cell_flops(hidden_size, hidden_size, module_name)
             total_ops += cell_flops
         total_ops *= num_steps
-        total_ops *= batch_size
         return total_ops
 
     def _count_rnn(self, m, x, y):
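Written out as plain per-sample arithmetic, the aggregation in the hunk above sums the first layer's cell FLOPs with one cell per remaining layer (doubled for bidirectional modules) and multiplies by the number of time steps, with no batch factor. A hedged sketch follows; `cell_flops` is a hypothetical stand-in for `_count_cell_flops`, whose gate arithmetic is not part of this diff, so the gate counts below are only an assumption:

```python
# Hypothetical sketch of the per-sample aggregation shown above.
# `cell_flops` is an assumed stand-in for ModelProfiler._count_cell_flops.
def cell_flops(input_size, hidden_size, kind):
    gates = {'rnn': 1, 'gru': 3, 'lstm': 4}[kind]            # assumed gate counts
    return gates * hidden_size * (input_size + hidden_size)

def rnn_module_flops_per_sample(input_size, hidden_size, num_layers,
                                num_steps, bidirectional, kind):
    total = cell_flops(input_size, hidden_size, kind)        # first layer
    for _ in range(num_layers - 1):
        if bidirectional:
            total += cell_flops(hidden_size * 2, hidden_size, kind) * 2
        else:
            total += cell_flops(hidden_size, hidden_size, kind)
    return total * num_steps                                 # note: no batch factor

# e.g. a 2-layer LSTM, 10 time steps, counted per sample:
print(rnn_module_flops_per_sample(32, 64, 2, 10, False, 'lstm'))
```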
@@ -272,7 +261,6 @@ class ModelProfiler:
         return self._get_result(m, total_ops)
 
     def count_module(self, m, x, y, name):
-        # assume x is tuple of single tensor
         result = self.ops[type(m)](m, x, y)
@@ -335,7 +323,10 @@ def count_flops_params(model, x, custom_ops=None, verbose=True, mode='default'):
     identify the mask on the module and take the pruned shape into consideration.
     Note that, for structured pruning, we only identify the remaining filters
     according to their masks, and do not take the pruned input channels into consideration,
     so the calculated FLOPs will be larger than the real number.
+    The FLOPs are counted "per sample", which means that even if the input has a batch size
+    larger than 1, the calculated FLOPs should not differ from the batch-size-1 case.
 
     Parameters
     ---------
......
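The docstring's caveats can be made concrete with a hedged sketch (not NNI's implementation; `conv2d_flops_per_sample` and `weight_mask` are illustrative names): with structured pruning, only output filters whose mask is non-zero are counted, but pruned input channels are still counted in full, so the result is an upper bound on the real FLOPs; and because the output is indexed per sample, the count does not change with batch size:

```python
import torch
import torch.nn as nn

# Hypothetical illustration of the docstring's two caveats; not NNI code.
def conv2d_flops_per_sample(conv, out, weight_mask=None):
    out_h, out_w = out.shape[2], out.shape[3]                # out: (bs, C_out, H, W)
    kernel_ops = conv.in_channels // conv.groups * conv.kernel_size[0] * conv.kernel_size[1]
    out_channels = conv.out_channels
    if weight_mask is not None:
        # count only the filters that survive structured pruning...
        out_channels = int((weight_mask.sum(dim=(1, 2, 3)) > 0).sum())
        # ...but pruned *input* channels are still counted, so this over-counts.
    return out_channels * out_h * out_w * kernel_ops         # per sample: no bs factor

conv = nn.Conv2d(16, 32, 3, padding=1)
x = torch.randn(4, 16, 8, 8)
y = conv(x)
mask = torch.ones_like(conv.weight)
mask[:8] = 0                                                  # prune 8 of 32 filters
print(conv2d_flops_per_sample(conv, y), conv2d_flops_per_sample(conv, y, mask))
```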
@@ -138,6 +138,7 @@ class AnalysisUtilsTest(TestCase):
         assert b_index1 == b_index2
 
+
 class FlopsCounterTest(TestCase):
     def test_flops_params(self):
         class Model1(nn.Module):
             def __init__(self):
@@ -170,16 +171,17 @@ class AnalysisUtilsTest(TestCase):
                 for _ in range(5):
                     x = self.conv2(x)
                 return x
 
-        flops, params, results = count_flops_params(Model1(), (1, 3, 2, 2), mode='full', verbose=False)
-        assert (flops, params) == (610, 240)
-
-        flops, params, results = count_flops_params(Model2(), (1, 3, 2, 2), verbose=False)
-        assert (flops, params) == (560, 50)
-
-        from torchvision.models import resnet50
-        flops, params, results = count_flops_params(resnet50(), (1, 3, 224, 224), verbose=False)
-        assert (flops, params) == (4089184256, 25503912)
+        for bs in [1, 2]:
+            flops, params, results = count_flops_params(Model1(), (bs, 3, 2, 2), mode='full', verbose=False)
+            assert (flops, params) == (610, 240)
+
+            flops, params, results = count_flops_params(Model2(), (bs, 3, 2, 2), verbose=False)
+            assert (flops, params) == (560, 50)
+
+            from torchvision.models import resnet50
+            flops, params, results = count_flops_params(resnet50(), (bs, 3, 224, 224), verbose=False)
+            assert (flops, params) == (4089184256, 25503912)
 
 
 if __name__ == '__main__':
......