Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
bcc55c52
"src/nni_manager/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "cf3d434f767d0de9809ddf0e7b9445e7bc36a258"
Unverified
Commit
bcc55c52
authored
Sep 10, 2021
by
Yuge Zhang
Committed by
GitHub
Sep 10, 2021
Browse files
Fix flops counter for `bs>1` (#4154)
parent
e98ebcf0
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
24 additions
and
31 deletions
+24
-31
nni/compression/pytorch/utils/counter.py
nni/compression/pytorch/utils/counter.py
+14
-23
test/ut/sdk/test_compression_utils.py
test/ut/sdk/test_compression_utils.py
+10
-8
No files found.
nni/compression/pytorch/utils/counter.py
View file @
bcc55c52
...
@@ -121,15 +121,15 @@ class ModelProfiler:
...
@@ -121,15 +121,15 @@ class ModelProfiler:
return
self
.
_get_result
(
m
,
total_ops
)
return
self
.
_get_result
(
m
,
total_ops
)
def
_count_bn
(
self
,
m
,
x
,
y
):
def
_count_bn
(
self
,
m
,
x
,
y
):
total_ops
=
2
*
x
[
0
].
numel
()
total_ops
=
2
*
x
[
0
]
[
0
]
.
numel
()
return
self
.
_get_result
(
m
,
total_ops
)
return
self
.
_get_result
(
m
,
total_ops
)
def
_count_relu
(
self
,
m
,
x
,
y
):
def
_count_relu
(
self
,
m
,
x
,
y
):
total_ops
=
x
[
0
].
numel
()
total_ops
=
x
[
0
]
[
0
]
.
numel
()
return
self
.
_get_result
(
m
,
total_ops
)
return
self
.
_get_result
(
m
,
total_ops
)
def
_count_avgpool
(
self
,
m
,
x
,
y
):
def
_count_avgpool
(
self
,
m
,
x
,
y
):
total_ops
=
y
.
numel
()
total_ops
=
y
[
0
]
.
numel
()
return
self
.
_get_result
(
m
,
total_ops
)
return
self
.
_get_result
(
m
,
total_ops
)
def
_count_adap_avgpool
(
self
,
m
,
x
,
y
):
def
_count_adap_avgpool
(
self
,
m
,
x
,
y
):
...
@@ -137,27 +137,27 @@ class ModelProfiler:
...
@@ -137,27 +137,27 @@ class ModelProfiler:
total_add
=
int
(
torch
.
prod
(
kernel
))
total_add
=
int
(
torch
.
prod
(
kernel
))
total_div
=
1
total_div
=
1
kernel_ops
=
total_add
+
total_div
kernel_ops
=
total_add
+
total_div
num_elements
=
y
.
numel
()
num_elements
=
y
[
0
]
.
numel
()
total_ops
=
kernel_ops
*
num_elements
total_ops
=
kernel_ops
*
num_elements
return
self
.
_get_result
(
m
,
total_ops
)
return
self
.
_get_result
(
m
,
total_ops
)
def
_count_upsample
(
self
,
m
,
x
,
y
):
def
_count_upsample
(
self
,
m
,
x
,
y
):
if
m
.
mode
==
'linear'
:
if
m
.
mode
==
'linear'
:
total_ops
=
y
.
nelement
()
*
5
# 2 muls + 3 add
total_ops
=
y
[
0
]
.
nelement
()
*
5
# 2 muls + 3 add
elif
m
.
mode
==
'bilinear'
:
elif
m
.
mode
==
'bilinear'
:
# https://en.wikipedia.org/wiki/Bilinear_interpolation
# https://en.wikipedia.org/wiki/Bilinear_interpolation
total_ops
=
y
.
nelement
()
*
11
# 6 muls + 5 adds
total_ops
=
y
[
0
]
.
nelement
()
*
11
# 6 muls + 5 adds
elif
m
.
mode
==
'bicubic'
:
elif
m
.
mode
==
'bicubic'
:
# https://en.wikipedia.org/wiki/Bicubic_interpolation
# https://en.wikipedia.org/wiki/Bicubic_interpolation
# Product matrix [4x4] x [4x4] x [4x4]
# Product matrix [4x4] x [4x4] x [4x4]
ops_solve_A
=
224
# 128 muls + 96 adds
ops_solve_A
=
224
# 128 muls + 96 adds
ops_solve_p
=
35
# 16 muls + 12 adds + 4 muls + 3 adds
ops_solve_p
=
35
# 16 muls + 12 adds + 4 muls + 3 adds
total_ops
=
y
.
nelement
()
*
(
ops_solve_A
+
ops_solve_p
)
total_ops
=
y
[
0
]
.
nelement
()
*
(
ops_solve_A
+
ops_solve_p
)
elif
m
.
mode
==
'trilinear'
:
elif
m
.
mode
==
'trilinear'
:
# https://en.wikipedia.org/wiki/Trilinear_interpolation
# https://en.wikipedia.org/wiki/Trilinear_interpolation
# can be viewed as 2 bilinear + 1 linear
# can be viewed as 2 bilinear + 1 linear
total_ops
=
y
.
nelement
()
*
(
13
*
2
+
5
)
total_ops
=
y
[
0
]
.
nelement
()
*
(
13
*
2
+
5
)
else
:
else
:
total_ops
=
0
total_ops
=
0
...
@@ -202,26 +202,16 @@ class ModelProfiler:
...
@@ -202,26 +202,16 @@ class ModelProfiler:
return
total_ops
return
total_ops
def
_count_rnn_cell
(
self
,
m
,
x
,
y
):
def
_count_rnn_cell
(
self
,
m
,
x
,
y
):
total_ops
=
self
.
_count_cell_flops
(
m
.
input_size
,
m
.
hidden_size
,
'rnn'
)
total_ops
=
self
.
_count_cell_flops
(
m
.
input_size
,
m
.
hidden_size
,
'rnn'
)
batch_size
=
x
[
0
].
size
(
0
)
total_ops
*=
batch_size
return
self
.
_get_result
(
m
,
total_ops
)
return
self
.
_get_result
(
m
,
total_ops
)
def
_count_gru_cell
(
self
,
m
,
x
,
y
):
def
_count_gru_cell
(
self
,
m
,
x
,
y
):
total_ops
=
self
.
_count_cell_flops
(
m
.
input_size
,
m
.
hidden_size
,
'gru'
)
total_ops
=
self
.
_count_cell_flops
(
m
.
input_size
,
m
.
hidden_size
,
'gru'
)
batch_size
=
x
[
0
].
size
(
0
)
total_ops
*=
batch_size
return
self
.
_get_result
(
m
,
total_ops
)
return
self
.
_get_result
(
m
,
total_ops
)
def
_count_lstm_cell
(
self
,
m
,
x
,
y
):
def
_count_lstm_cell
(
self
,
m
,
x
,
y
):
total_ops
=
self
.
_count_cell_flops
(
m
.
input_size
,
m
.
hidden_size
,
'lstm'
)
total_ops
=
self
.
_count_cell_flops
(
m
.
input_size
,
m
.
hidden_size
,
'lstm'
)
batch_size
=
x
[
0
].
size
(
0
)
total_ops
*=
batch_size
return
self
.
_get_result
(
m
,
total_ops
)
return
self
.
_get_result
(
m
,
total_ops
)
def
_get_bsize_nsteps
(
self
,
m
,
x
):
def
_get_bsize_nsteps
(
self
,
m
,
x
):
...
@@ -243,18 +233,17 @@ class ModelProfiler:
...
@@ -243,18 +233,17 @@ class ModelProfiler:
hidden_size
=
m
.
hidden_size
hidden_size
=
m
.
hidden_size
num_layers
=
m
.
num_layers
num_layers
=
m
.
num_layers
batch_size
,
num_steps
=
self
.
_get_bsize_nsteps
(
m
,
x
)
_
,
num_steps
=
self
.
_get_bsize_nsteps
(
m
,
x
)
total_ops
=
self
.
_count_cell_flops
(
input_size
,
hidden_size
,
module_name
)
total_ops
=
self
.
_count_cell_flops
(
input_size
,
hidden_size
,
module_name
)
for
_
in
range
(
num_layers
-
1
):
for
_
in
range
(
num_layers
-
1
):
if
m
.
bidirectional
:
if
m
.
bidirectional
:
cell_flops
=
self
.
_count_cell_flops
(
hidden_size
*
2
,
hidden_size
,
module_name
)
*
2
cell_flops
=
self
.
_count_cell_flops
(
hidden_size
*
2
,
hidden_size
,
module_name
)
*
2
else
:
else
:
cell_flops
=
self
.
_count_cell_flops
(
hidden_size
,
hidden_size
,
module_name
)
cell_flops
=
self
.
_count_cell_flops
(
hidden_size
,
hidden_size
,
module_name
)
total_ops
+=
cell_flops
total_ops
+=
cell_flops
total_ops
*=
num_steps
total_ops
*=
num_steps
total_ops
*=
batch_size
return
total_ops
return
total_ops
def
_count_rnn
(
self
,
m
,
x
,
y
):
def
_count_rnn
(
self
,
m
,
x
,
y
):
...
@@ -272,7 +261,6 @@ class ModelProfiler:
...
@@ -272,7 +261,6 @@ class ModelProfiler:
return
self
.
_get_result
(
m
,
total_ops
)
return
self
.
_get_result
(
m
,
total_ops
)
def
count_module
(
self
,
m
,
x
,
y
,
name
):
def
count_module
(
self
,
m
,
x
,
y
,
name
):
# assume x is tuple of single tensor
# assume x is tuple of single tensor
result
=
self
.
ops
[
type
(
m
)](
m
,
x
,
y
)
result
=
self
.
ops
[
type
(
m
)](
m
,
x
,
y
)
...
@@ -337,6 +325,9 @@ def count_flops_params(model, x, custom_ops=None, verbose=True, mode='default'):
...
@@ -337,6 +325,9 @@ def count_flops_params(model, x, custom_ops=None, verbose=True, mode='default'):
according to its mask, and do not take the pruned input channels into consideration,
according to its mask, and do not take the pruned input channels into consideration,
so the calculated FLOPs will be larger than real number.
so the calculated FLOPs will be larger than real number.
The FLOPs are counted "per sample", which means that when the input has a batch size larger than 1,
the calculated FLOPs should not differ from those computed with a batch size of 1.
Parameters
Parameters
---------
---------
model : nn.Module
model : nn.Module
...
...
test/ut/sdk/test_compression_utils.py
View file @
bcc55c52
...
@@ -138,6 +138,7 @@ class AnalysisUtilsTest(TestCase):
...
@@ -138,6 +138,7 @@ class AnalysisUtilsTest(TestCase):
assert
b_index1
==
b_index2
assert
b_index1
==
b_index2
class
FlopsCounterTest
(
TestCase
):
def
test_flops_params
(
self
):
def
test_flops_params
(
self
):
class
Model1
(
nn
.
Module
):
class
Model1
(
nn
.
Module
):
def
__init__
(
self
):
def
__init__
(
self
):
...
@@ -171,14 +172,15 @@ class AnalysisUtilsTest(TestCase):
...
@@ -171,14 +172,15 @@ class AnalysisUtilsTest(TestCase):
x
=
self
.
conv2
(
x
)
x
=
self
.
conv2
(
x
)
return
x
return
x
flops
,
params
,
results
=
count_flops_params
(
Model1
(),
(
1
,
3
,
2
,
2
),
mode
=
'full'
,
verbose
=
False
)
for
bs
in
[
1
,
2
]:
flops
,
params
,
results
=
count_flops_params
(
Model1
(),
(
bs
,
3
,
2
,
2
),
mode
=
'full'
,
verbose
=
False
)
assert
(
flops
,
params
)
==
(
610
,
240
)
assert
(
flops
,
params
)
==
(
610
,
240
)
flops
,
params
,
results
=
count_flops_params
(
Model2
(),
(
1
,
3
,
2
,
2
),
verbose
=
False
)
flops
,
params
,
results
=
count_flops_params
(
Model2
(),
(
bs
,
3
,
2
,
2
),
verbose
=
False
)
assert
(
flops
,
params
)
==
(
560
,
50
)
assert
(
flops
,
params
)
==
(
560
,
50
)
from
torchvision.models
import
resnet50
from
torchvision.models
import
resnet50
flops
,
params
,
results
=
count_flops_params
(
resnet50
(),
(
1
,
3
,
224
,
224
),
verbose
=
False
)
flops
,
params
,
results
=
count_flops_params
(
resnet50
(),
(
bs
,
3
,
224
,
224
),
verbose
=
False
)
assert
(
flops
,
params
)
==
(
4089184256
,
25503912
)
assert
(
flops
,
params
)
==
(
4089184256
,
25503912
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment