OpenDAS / nni · Commits · Commit 7fd07766 (Unverified)
Authored Apr 21, 2021 by colorjam, committed by GitHub on Apr 21, 2021

Fix bug of FLOPs counter (#3497)

Parent: 638da0bd

Showing 1 changed file with 126 additions and 11 deletions:
nni/compression/pytorch/utils/counter.py (+126 −11)
@@ -7,6 +7,7 @@ from prettytable import PrettyTable
 import torch
 import torch.nn as nn
+from torch.nn.utils.rnn import PackedSequence
 from nni.compression.pytorch.compressor import PrunerModuleWrapper
@@ -32,21 +33,27 @@ class ModelProfiler:
             for reference, please see ``self.ops``.
         mode:
             the mode of how to collect information. If the mode is set to `default`,
-            only the information of convolution and linear will be collected.
+            only the information of convolution, linear and rnn modules will be collected.
             If the mode is set to `full`, other operations will also be collected.
         """
         self.ops = {nn.Conv1d: self._count_convNd,
                     nn.Conv2d: self._count_convNd,
                     nn.Conv3d: self._count_convNd,
-                    nn.Linear: self._count_linear}
+                    nn.ConvTranspose1d: self._count_convNd,
+                    nn.ConvTranspose2d: self._count_convNd,
+                    nn.ConvTranspose3d: self._count_convNd,
+                    nn.Linear: self._count_linear,
+                    nn.RNNCell: self._count_rnn_cell,
+                    nn.GRUCell: self._count_gru_cell,
+                    nn.LSTMCell: self._count_lstm_cell,
+                    nn.RNN: self._count_rnn,
+                    nn.GRU: self._count_gru,
+                    nn.LSTM: self._count_lstm}

         self._count_bias = False
         if mode == 'full':
             self.ops.update({
-                nn.ConvTranspose1d: self._count_convNd,
-                nn.ConvTranspose2d: self._count_convNd,
-                nn.ConvTranspose3d: self._count_convNd,
                 nn.BatchNorm1d: self._count_bn,
                 nn.BatchNorm2d: self._count_bn,
                 nn.BatchNorm3d: self._count_bn,
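The counters are dispatched by exact module type through this dict. A minimal standalone sketch of that lookup pattern, with a toy count_linear that is not the diff's own counter:

    import torch.nn as nn

    def count_linear(m, x, y):
        # toy counter: multiply-accumulates for a single sample
        return m.in_features * m.out_features

    ops = {nn.Linear: count_linear}

    layer = nn.Linear(4, 8)
    print(ops[type(layer)](layer, None, None))  # 32; exact-type lookup, no isinstance()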
@@ -86,7 +93,7 @@ class ModelProfiler:
     def _count_convNd(self, m, x, y):
         cin = m.in_channels
-        kernel_ops = m.weight.size()[2] * m.weight.size()[3]
+        kernel_ops = torch.zeros(m.weight.size()[2:]).numel()
         output_size = torch.zeros(y.size()[2:]).numel()
         cout = y.size()[1]
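The fixed line is what makes the counter dimension-agnostic: a Conv1d weight has shape (cout, cin, k), so indexing size()[3] raised an IndexError, and for Conv3d the old product missed the last kernel dimension. Slicing from dim 2 covers 1d/2d/3d kernels alike. A standalone check (shapes chosen for illustration):

    import torch
    import torch.nn as nn

    for conv in (nn.Conv1d(3, 8, 5), nn.Conv2d(3, 8, 5), nn.Conv3d(3, 8, 5)):
        kernel_ops = torch.zeros(conv.weight.size()[2:]).numel()
        print(type(conv).__name__, kernel_ops)  # 5, 25, 125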
@@ -156,13 +163,125 @@ class ModelProfiler:
         return self._get_result(m, total_ops)

+    def _count_cell_flops(self, input_size, hidden_size, cell_type):
+        # h' = \tanh(W_{ih} x + b_{ih} + W_{hh} h + b_{hh})
+        total_ops = hidden_size * (input_size + hidden_size) + hidden_size
+
+        if self._count_bias:
+            total_ops += hidden_size * 2
+
+        if cell_type == 'rnn':
+            return total_ops
+
+        if cell_type == 'gru':
+            # r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr})
+            # z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz})
+            # n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn}))
+            total_ops *= 3
+
+            # r hadamard: r * (W_{hn} h + b_{hn})
+            total_ops += hidden_size
+
+            # h' = (1 - z) * n + z * h
+            # hadamard, hadamard, add
+            total_ops += hidden_size * 3
+        elif cell_type == 'lstm':
+            # i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi})
+            # f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf})
+            # o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho})
+            # g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg})
+            total_ops *= 4
+
+            # c' = f * c + i * g
+            # hadamard, hadamard, add
+            total_ops += hidden_size * 3
+
+            # h' = o * \tanh(c')
+            total_ops += hidden_size
+
+        return total_ops
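As a sanity check of the formula, here is the same arithmetic unrolled for an LSTM cell with input_size=16 and hidden_size=32 (sizes chosen for illustration; bias terms off, mirroring _count_bias = False):

    input_size, hidden_size = 16, 32
    gate_ops = hidden_size * (input_size + hidden_size) + hidden_size  # 1568 per gate
    total_ops = gate_ops * 4        # i, f, o, g gates
    total_ops += hidden_size * 3    # c' = f * c + i * g: two hadamards, one add
    total_ops += hidden_size        # h' = o * tanh(c')
    print(total_ops)                # 6400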
+    def _count_rnn_cell(self, m, x, y):
+        total_ops = self._count_cell_flops(m.input_size, m.hidden_size, 'rnn')
+        batch_size = x[0].size(0)
+        total_ops *= batch_size
+        return self._get_result(m, total_ops)
+
+    def _count_gru_cell(self, m, x, y):
+        total_ops = self._count_cell_flops(m.input_size, m.hidden_size, 'gru')
+        batch_size = x[0].size(0)
+        total_ops *= batch_size
+        return self._get_result(m, total_ops)
+
+    def _count_lstm_cell(self, m, x, y):
+        total_ops = self._count_cell_flops(m.input_size, m.hidden_size, 'lstm')
+        batch_size = x[0].size(0)
+        total_ops *= batch_size
+        return self._get_result(m, total_ops)
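The x these hooks receive is the tuple of forward() arguments, so x[0] is the (batch, input_size) tensor and x[0].size(0) is the batch size. For example, with illustrative sizes:

    import torch
    import torch.nn as nn

    cell = nn.LSTMCell(16, 32)
    inp = torch.randn(4, 16)        # (batch, input_size)
    h, c = cell(inp, (torch.zeros(4, 32), torch.zeros(4, 32)))
    print(inp.size(0))              # 4: per-step cell FLOPs scale by this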
+    def _get_bsize_nsteps(self, m, x):
+        if isinstance(x[0], PackedSequence):
+            batch_size = torch.max(x[0].batch_sizes)
+            num_steps = x[0].batch_sizes.size(0)
+        else:
+            if m.batch_first:
+                batch_size = x[0].size(0)
+                num_steps = x[0].size(1)
+            else:
+                batch_size = x[0].size(1)
+                num_steps = x[0].size(0)
+        return batch_size, num_steps
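This helper covers both input conventions: padded tensors, where the batch and step dims depend on batch_first, and PackedSequence inputs, where batch_sizes encodes both. A standalone illustration (shapes chosen for illustration):

    import torch
    from torch.nn.utils.rnn import pack_padded_sequence

    x = torch.randn(4, 7, 16)                    # batch_first: (batch, steps, features)
    print(x.size(0), x.size(1))                  # batch_size=4, num_steps=7

    packed = pack_padded_sequence(x, lengths=torch.tensor([7, 5, 3, 2]),
                                  batch_first=True)
    print(torch.max(packed.batch_sizes).item(),  # 4: largest per-step batch
          packed.batch_sizes.size(0))            # 7: number of time steps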
+    def _count_rnn_module(self, m, x, y, module_name):
+        input_size = m.input_size
+        hidden_size = m.hidden_size
+        num_layers = m.num_layers
+        batch_size, num_steps = self._get_bsize_nsteps(m, x)
+
+        total_ops = self._count_cell_flops(input_size, hidden_size, module_name)
+
+        for _ in range(num_layers - 1):
+            if m.bidirectional:
+                cell_flops = self._count_cell_flops(hidden_size * 2, hidden_size, module_name) * 2
+            else:
+                cell_flops = self._count_cell_flops(hidden_size, hidden_size, module_name)
+            total_ops += cell_flops
+
+        total_ops *= num_steps
+        total_ops *= batch_size
+        return total_ops
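Unrolled for a concrete case, the per-module total for a 2-layer bidirectional LSTM (input_size=16, hidden_size=32, batch of 4, 7 steps; all sizes illustrative) comes out as follows. cell_flops mirrors _count_cell_flops for 'lstm' without bias:

    def cell_flops(input_size, hidden_size):
        gate_ops = hidden_size * (input_size + hidden_size) + hidden_size
        return gate_ops * 4 + hidden_size * 3 + hidden_size

    input_size, hidden_size, num_layers = 16, 32, 2
    batch_size, num_steps = 4, 7

    total_ops = cell_flops(input_size, hidden_size)      # layer 0
    for _ in range(num_layers - 1):                      # deeper layers consume 2*H inputs
        total_ops += cell_flops(hidden_size * 2, hidden_size) * 2
    total_ops *= num_steps * batch_size
    print(total_ops)                                     # 881664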
+    def _count_rnn(self, m, x, y):
+        total_ops = self._count_rnn_module(m, x, y, 'rnn')
+        return self._get_result(m, total_ops)
+
+    def _count_gru(self, m, x, y):
+        total_ops = self._count_rnn_module(m, x, y, 'gru')
+        return self._get_result(m, total_ops)
+
+    def _count_lstm(self, m, x, y):
+        total_ops = self._count_rnn_module(m, x, y, 'lstm')
+        return self._get_result(m, total_ops)
+
     def count_module(self, m, x, y, name):
         # assume x is tuple of single tensor
         result = self.ops[type(m)](m, x, y)
+        output_size = y[0].size() if isinstance(y, tuple) else y.size()

         total_result = {
             'name': name,
             'input_size': tuple(x[0].size()),
-            'output_size': tuple(y.size()),
+            'output_size': tuple(output_size),
             'module_type': type(m).__name__,
             **result
         }
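The output_size change fixes a crash for the newly supported module types: recurrent modules return a tuple from forward(), so the old tuple(y.size()) raised an AttributeError. A standalone illustration (shapes chosen for illustration):

    import torch
    import torch.nn as nn

    lstm = nn.LSTM(16, 32, batch_first=True)
    y = lstm(torch.randn(4, 7, 16))     # y == (output, (h_n, c_n))
    output_size = y[0].size() if isinstance(y, tuple) else y.size()
    print(tuple(output_size))           # (4, 7, 32)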
@@ -279,10 +398,6 @@ def count_flops_params(model, x, custom_ops=None, verbose=True, mode='default'):
     model(*x)

     # restore origin status
-    for name, m in model.named_modules():
-        if hasattr(m, 'weight_mask'):
-            delattr(m, 'weight_mask')
     model.train(training).to(original_device)

     for handler in handler_collection:
         handler.remove()
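Taken together, the patch lets count_flops_params profile RNN-family models out of the box. A minimal usage sketch, assuming the (flops, params) return of this revision and passing the input as a tuple, as the visible model(*x) call expects:

    import torch
    import torch.nn as nn
    from nni.compression.pytorch.utils.counter import count_flops_params

    model = nn.LSTM(input_size=16, hidden_size=32, num_layers=2, batch_first=True)
    dummy_input = torch.randn(1, 7, 16)            # (batch, steps, features)

    flops, params = count_flops_params(model, (dummy_input,))  # return shape assumed
    print(flops, params)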