Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
bcc55c52
Unverified
Commit
bcc55c52
authored
Sep 10, 2021
by
Yuge Zhang
Committed by
GitHub
Sep 10, 2021
Browse files
Fix flops counter for `bs>1` (#4154)
parent
e98ebcf0
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
24 additions
and
31 deletions
+24
-31
nni/compression/pytorch/utils/counter.py
nni/compression/pytorch/utils/counter.py
+14
-23
test/ut/sdk/test_compression_utils.py
test/ut/sdk/test_compression_utils.py
+10
-8
No files found.
nni/compression/pytorch/utils/counter.py
View file @
bcc55c52
...
...
@@ -121,15 +121,15 @@ class ModelProfiler:
return
self
.
_get_result
(
m
,
total_ops
)
def
_count_bn
(
self
,
m
,
x
,
y
):
total_ops
=
2
*
x
[
0
].
numel
()
total_ops
=
2
*
x
[
0
]
[
0
]
.
numel
()
return
self
.
_get_result
(
m
,
total_ops
)
def
_count_relu
(
self
,
m
,
x
,
y
):
total_ops
=
x
[
0
].
numel
()
total_ops
=
x
[
0
]
[
0
]
.
numel
()
return
self
.
_get_result
(
m
,
total_ops
)
def
_count_avgpool
(
self
,
m
,
x
,
y
):
total_ops
=
y
.
numel
()
total_ops
=
y
[
0
]
.
numel
()
return
self
.
_get_result
(
m
,
total_ops
)
def
_count_adap_avgpool
(
self
,
m
,
x
,
y
):
...
...
@@ -137,27 +137,27 @@ class ModelProfiler:
total_add
=
int
(
torch
.
prod
(
kernel
))
total_div
=
1
kernel_ops
=
total_add
+
total_div
num_elements
=
y
.
numel
()
num_elements
=
y
[
0
]
.
numel
()
total_ops
=
kernel_ops
*
num_elements
return
self
.
_get_result
(
m
,
total_ops
)
def
_count_upsample
(
self
,
m
,
x
,
y
):
if
m
.
mode
==
'linear'
:
total_ops
=
y
.
nelement
()
*
5
# 2 muls + 3 add
total_ops
=
y
[
0
]
.
nelement
()
*
5
# 2 muls + 3 add
elif
m
.
mode
==
'bilinear'
:
# https://en.wikipedia.org/wiki/Bilinear_interpolation
total_ops
=
y
.
nelement
()
*
11
# 6 muls + 5 adds
total_ops
=
y
[
0
]
.
nelement
()
*
11
# 6 muls + 5 adds
elif
m
.
mode
==
'bicubic'
:
# https://en.wikipedia.org/wiki/Bicubic_interpolation
# Product matrix [4x4] x [4x4] x [4x4]
ops_solve_A
=
224
# 128 muls + 96 adds
ops_solve_p
=
35
# 16 muls + 12 adds + 4 muls + 3 adds
total_ops
=
y
.
nelement
()
*
(
ops_solve_A
+
ops_solve_p
)
total_ops
=
y
[
0
]
.
nelement
()
*
(
ops_solve_A
+
ops_solve_p
)
elif
m
.
mode
==
'trilinear'
:
# https://en.wikipedia.org/wiki/Trilinear_interpolation
# can viewed as 2 bilinear + 1 linear
total_ops
=
y
.
nelement
()
*
(
13
*
2
+
5
)
total_ops
=
y
[
0
]
.
nelement
()
*
(
13
*
2
+
5
)
else
:
total_ops
=
0
...
...
@@ -202,26 +202,16 @@ class ModelProfiler:
return
total_ops
def
_count_rnn_cell
(
self
,
m
,
x
,
y
):
total_ops
=
self
.
_count_cell_flops
(
m
.
input_size
,
m
.
hidden_size
,
'rnn'
)
batch_size
=
x
[
0
].
size
(
0
)
total_ops
*=
batch_size
return
self
.
_get_result
(
m
,
total_ops
)
def
_count_gru_cell
(
self
,
m
,
x
,
y
):
total_ops
=
self
.
_count_cell_flops
(
m
.
input_size
,
m
.
hidden_size
,
'gru'
)
batch_size
=
x
[
0
].
size
(
0
)
total_ops
*=
batch_size
return
self
.
_get_result
(
m
,
total_ops
)
def
_count_lstm_cell
(
self
,
m
,
x
,
y
):
total_ops
=
self
.
_count_cell_flops
(
m
.
input_size
,
m
.
hidden_size
,
'lstm'
)
batch_size
=
x
[
0
].
size
(
0
)
total_ops
*=
batch_size
return
self
.
_get_result
(
m
,
total_ops
)
def
_get_bsize_nsteps
(
self
,
m
,
x
):
...
...
@@ -243,18 +233,17 @@ class ModelProfiler:
hidden_size
=
m
.
hidden_size
num_layers
=
m
.
num_layers
batch_size
,
num_steps
=
self
.
_get_bsize_nsteps
(
m
,
x
)
_
,
num_steps
=
self
.
_get_bsize_nsteps
(
m
,
x
)
total_ops
=
self
.
_count_cell_flops
(
input_size
,
hidden_size
,
module_name
)
for
_
in
range
(
num_layers
-
1
):
if
m
.
bidirectional
:
cell_flops
=
self
.
_count_cell_flops
(
hidden_size
*
2
,
hidden_size
,
module_name
)
*
2
else
:
cell_flops
=
self
.
_count_cell_flops
(
hidden_size
,
hidden_size
,
module_name
)
cell_flops
=
self
.
_count_cell_flops
(
hidden_size
,
hidden_size
,
module_name
)
total_ops
+=
cell_flops
total_ops
*=
num_steps
total_ops
*=
batch_size
return
total_ops
def
_count_rnn
(
self
,
m
,
x
,
y
):
...
...
@@ -272,7 +261,6 @@ class ModelProfiler:
return
self
.
_get_result
(
m
,
total_ops
)
def
count_module
(
self
,
m
,
x
,
y
,
name
):
# assume x is tuple of single tensor
result
=
self
.
ops
[
type
(
m
)](
m
,
x
,
y
)
...
...
@@ -335,7 +323,10 @@ def count_flops_params(model, x, custom_ops=None, verbose=True, mode='default'):
identify the mask on the module and take the pruned shape into consideration.
Note that, for sturctured pruning, we only identify the remained filters
according to its mask, and do not take the pruned input channels into consideration,
so the calculated FLOPs will be larger than real number.
so the calculated FLOPs will be larger than real number.
The FLOPs is counted "per sample", which means that input has a batch size larger than 1,
the calculated FLOPs should not differ from batch size of 1.
Parameters
---------
...
...
test/ut/sdk/test_compression_utils.py
View file @
bcc55c52
...
...
@@ -138,6 +138,7 @@ class AnalysisUtilsTest(TestCase):
assert
b_index1
==
b_index2
class
FlopsCounterTest
(
TestCase
):
def
test_flops_params
(
self
):
class
Model1
(
nn
.
Module
):
def
__init__
(
self
):
...
...
@@ -170,16 +171,17 @@ class AnalysisUtilsTest(TestCase):
for
_
in
range
(
5
):
x
=
self
.
conv2
(
x
)
return
x
flops
,
params
,
results
=
count_flops_params
(
Model1
(),
(
1
,
3
,
2
,
2
),
mode
=
'full'
,
verbose
=
False
)
assert
(
flops
,
params
)
==
(
610
,
240
)
flops
,
params
,
results
=
count_flops_params
(
Model2
(),
(
1
,
3
,
2
,
2
),
verbose
=
False
)
assert
(
flops
,
params
)
==
(
560
,
50
)
for
bs
in
[
1
,
2
]:
flops
,
params
,
results
=
count_flops_params
(
Model1
(),
(
bs
,
3
,
2
,
2
),
mode
=
'full'
,
verbose
=
False
)
assert
(
flops
,
params
)
==
(
610
,
240
)
from
torchvision.models
import
resnet50
flops
,
params
,
results
=
count_flops_params
(
resnet50
(),
(
1
,
3
,
224
,
224
),
verbose
=
False
)
assert
(
flops
,
params
)
==
(
4089184256
,
25503912
)
flops
,
params
,
results
=
count_flops_params
(
Model2
(),
(
bs
,
3
,
2
,
2
),
verbose
=
False
)
assert
(
flops
,
params
)
==
(
560
,
50
)
from
torchvision.models
import
resnet50
flops
,
params
,
results
=
count_flops_params
(
resnet50
(),
(
bs
,
3
,
224
,
224
),
verbose
=
False
)
assert
(
flops
,
params
)
==
(
4089184256
,
25503912
)
if
__name__
==
'__main__'
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment