"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "449df64875f3255a3bac934efbf9d8f610700ffa"
Unverified Commit bcc55c52 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Fix flops counter for `bs>1` (#4154)

parent e98ebcf0
...@@ -121,15 +121,15 @@ class ModelProfiler: ...@@ -121,15 +121,15 @@ class ModelProfiler:
return self._get_result(m, total_ops) return self._get_result(m, total_ops)
def _count_bn(self, m, x, y): def _count_bn(self, m, x, y):
total_ops = 2 * x[0].numel() total_ops = 2 * x[0][0].numel()
return self._get_result(m, total_ops) return self._get_result(m, total_ops)
def _count_relu(self, m, x, y): def _count_relu(self, m, x, y):
total_ops = x[0].numel() total_ops = x[0][0].numel()
return self._get_result(m, total_ops) return self._get_result(m, total_ops)
def _count_avgpool(self, m, x, y): def _count_avgpool(self, m, x, y):
total_ops = y.numel() total_ops = y[0].numel()
return self._get_result(m, total_ops) return self._get_result(m, total_ops)
def _count_adap_avgpool(self, m, x, y): def _count_adap_avgpool(self, m, x, y):
...@@ -137,27 +137,27 @@ class ModelProfiler: ...@@ -137,27 +137,27 @@ class ModelProfiler:
total_add = int(torch.prod(kernel)) total_add = int(torch.prod(kernel))
total_div = 1 total_div = 1
kernel_ops = total_add + total_div kernel_ops = total_add + total_div
num_elements = y.numel() num_elements = y[0].numel()
total_ops = kernel_ops * num_elements total_ops = kernel_ops * num_elements
return self._get_result(m, total_ops) return self._get_result(m, total_ops)
def _count_upsample(self, m, x, y): def _count_upsample(self, m, x, y):
if m.mode == 'linear': if m.mode == 'linear':
total_ops = y.nelement() * 5 # 2 muls + 3 add total_ops = y[0].nelement() * 5 # 2 muls + 3 add
elif m.mode == 'bilinear': elif m.mode == 'bilinear':
# https://en.wikipedia.org/wiki/Bilinear_interpolation # https://en.wikipedia.org/wiki/Bilinear_interpolation
total_ops = y.nelement() * 11 # 6 muls + 5 adds total_ops = y[0].nelement() * 11 # 6 muls + 5 adds
elif m.mode == 'bicubic': elif m.mode == 'bicubic':
# https://en.wikipedia.org/wiki/Bicubic_interpolation # https://en.wikipedia.org/wiki/Bicubic_interpolation
# Product matrix [4x4] x [4x4] x [4x4] # Product matrix [4x4] x [4x4] x [4x4]
ops_solve_A = 224 # 128 muls + 96 adds ops_solve_A = 224 # 128 muls + 96 adds
ops_solve_p = 35 # 16 muls + 12 adds + 4 muls + 3 adds ops_solve_p = 35 # 16 muls + 12 adds + 4 muls + 3 adds
total_ops = y.nelement() * (ops_solve_A + ops_solve_p) total_ops = y[0].nelement() * (ops_solve_A + ops_solve_p)
elif m.mode == 'trilinear': elif m.mode == 'trilinear':
# https://en.wikipedia.org/wiki/Trilinear_interpolation # https://en.wikipedia.org/wiki/Trilinear_interpolation
# can viewed as 2 bilinear + 1 linear # can viewed as 2 bilinear + 1 linear
total_ops = y.nelement() * (13 * 2 + 5) total_ops = y[0].nelement() * (13 * 2 + 5)
else: else:
total_ops = 0 total_ops = 0
...@@ -202,26 +202,16 @@ class ModelProfiler: ...@@ -202,26 +202,16 @@ class ModelProfiler:
return total_ops return total_ops
def _count_rnn_cell(self, m, x, y): def _count_rnn_cell(self, m, x, y):
total_ops = self._count_cell_flops(m.input_size, m.hidden_size, 'rnn') total_ops = self._count_cell_flops(m.input_size, m.hidden_size, 'rnn')
batch_size = x[0].size(0)
total_ops *= batch_size
return self._get_result(m, total_ops) return self._get_result(m, total_ops)
def _count_gru_cell(self, m, x, y): def _count_gru_cell(self, m, x, y):
total_ops = self._count_cell_flops(m.input_size, m.hidden_size, 'gru') total_ops = self._count_cell_flops(m.input_size, m.hidden_size, 'gru')
batch_size = x[0].size(0)
total_ops *= batch_size
return self._get_result(m, total_ops) return self._get_result(m, total_ops)
def _count_lstm_cell(self, m, x, y): def _count_lstm_cell(self, m, x, y):
total_ops = self._count_cell_flops(m.input_size, m.hidden_size, 'lstm') total_ops = self._count_cell_flops(m.input_size, m.hidden_size, 'lstm')
batch_size = x[0].size(0)
total_ops *= batch_size
return self._get_result(m, total_ops) return self._get_result(m, total_ops)
def _get_bsize_nsteps(self, m, x): def _get_bsize_nsteps(self, m, x):
...@@ -243,18 +233,17 @@ class ModelProfiler: ...@@ -243,18 +233,17 @@ class ModelProfiler:
hidden_size = m.hidden_size hidden_size = m.hidden_size
num_layers = m.num_layers num_layers = m.num_layers
batch_size, num_steps = self._get_bsize_nsteps(m, x) _, num_steps = self._get_bsize_nsteps(m, x)
total_ops = self._count_cell_flops(input_size, hidden_size, module_name) total_ops = self._count_cell_flops(input_size, hidden_size, module_name)
for _ in range(num_layers - 1): for _ in range(num_layers - 1):
if m.bidirectional: if m.bidirectional:
cell_flops = self._count_cell_flops(hidden_size * 2, hidden_size, module_name) * 2 cell_flops = self._count_cell_flops(hidden_size * 2, hidden_size, module_name) * 2
else: else:
cell_flops = self._count_cell_flops(hidden_size, hidden_size,module_name) cell_flops = self._count_cell_flops(hidden_size, hidden_size, module_name)
total_ops += cell_flops total_ops += cell_flops
total_ops *= num_steps total_ops *= num_steps
total_ops *= batch_size
return total_ops return total_ops
def _count_rnn(self, m, x, y): def _count_rnn(self, m, x, y):
...@@ -272,7 +261,6 @@ class ModelProfiler: ...@@ -272,7 +261,6 @@ class ModelProfiler:
return self._get_result(m, total_ops) return self._get_result(m, total_ops)
def count_module(self, m, x, y, name): def count_module(self, m, x, y, name):
# assume x is tuple of single tensor # assume x is tuple of single tensor
result = self.ops[type(m)](m, x, y) result = self.ops[type(m)](m, x, y)
...@@ -337,6 +325,9 @@ def count_flops_params(model, x, custom_ops=None, verbose=True, mode='default'): ...@@ -337,6 +325,9 @@ def count_flops_params(model, x, custom_ops=None, verbose=True, mode='default'):
according to its mask, and do not take the pruned input channels into consideration, according to its mask, and do not take the pruned input channels into consideration,
so the calculated FLOPs will be larger than real number. so the calculated FLOPs will be larger than real number.
The FLOPs are counted "per sample", which means that when the input has a batch size larger than 1,
the calculated FLOPs should not differ from those computed with a batch size of 1.
Parameters Parameters
--------- ---------
model : nn.Module model : nn.Module
......
...@@ -138,6 +138,7 @@ class AnalysisUtilsTest(TestCase): ...@@ -138,6 +138,7 @@ class AnalysisUtilsTest(TestCase):
assert b_index1 == b_index2 assert b_index1 == b_index2
class FlopsCounterTest(TestCase):
def test_flops_params(self): def test_flops_params(self):
class Model1(nn.Module): class Model1(nn.Module):
def __init__(self): def __init__(self):
...@@ -171,14 +172,15 @@ class AnalysisUtilsTest(TestCase): ...@@ -171,14 +172,15 @@ class AnalysisUtilsTest(TestCase):
x = self.conv2(x) x = self.conv2(x)
return x return x
flops, params, results = count_flops_params(Model1(), (1, 3, 2, 2), mode='full', verbose=False) for bs in [1, 2]:
flops, params, results = count_flops_params(Model1(), (bs, 3, 2, 2), mode='full', verbose=False)
assert (flops, params) == (610, 240) assert (flops, params) == (610, 240)
flops, params, results = count_flops_params(Model2(), (1, 3, 2, 2), verbose=False) flops, params, results = count_flops_params(Model2(), (bs, 3, 2, 2), verbose=False)
assert (flops, params) == (560, 50) assert (flops, params) == (560, 50)
from torchvision.models import resnet50 from torchvision.models import resnet50
flops, params, results = count_flops_params(resnet50(), (1, 3, 224, 224), verbose=False) flops, params, results = count_flops_params(resnet50(), (bs, 3, 224, 224), verbose=False)
assert (flops, params) == (4089184256, 25503912) assert (flops, params) == (4089184256, 25503912)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment