nni - Commit 7fd07766 (Unverified)
Authored Apr 21, 2021 by colorjam; committed by GitHub on Apr 21, 2021
Parent: 638da0bd

Fix bug of FLOPs counter (#3497)
Showing 1 changed file with 126 additions and 11 deletions.
nni/compression/pytorch/utils/counter.py (+126, -11)
@@ -7,6 +7,7 @@ from prettytable import PrettyTable
 import torch
 import torch.nn as nn
+from torch.nn.utils.rnn import PackedSequence
 from nni.compression.pytorch.compressor import PrunerModuleWrapper
@@ -32,21 +33,27 @@ class ModelProfiler:
             for reference, please see ``self.ops``.
         mode:
             the mode of how to collect information. If the mode is set to `default`,
-            only the information of convolution and linear will be collected.
+            only the information of convolution, linear and rnn modules will be collected.
             If the mode is set to `full`, other operations will also be collected.
         """
         self.ops = {
             nn.Conv1d: self._count_convNd,
             nn.Conv2d: self._count_convNd,
             nn.Conv3d: self._count_convNd,
-            nn.Linear: self._count_linear
+            nn.ConvTranspose1d: self._count_convNd,
+            nn.ConvTranspose2d: self._count_convNd,
+            nn.ConvTranspose3d: self._count_convNd,
+            nn.Linear: self._count_linear,
+            nn.RNNCell: self._count_rnn_cell,
+            nn.GRUCell: self._count_gru_cell,
+            nn.LSTMCell: self._count_lstm_cell,
+            nn.RNN: self._count_rnn,
+            nn.GRU: self._count_gru,
+            nn.LSTM: self._count_lstm
         }
         self._count_bias = False
         if mode == 'full':
             self.ops.update({
-                nn.ConvTranspose1d: self._count_convNd,
-                nn.ConvTranspose2d: self._count_convNd,
-                nn.ConvTranspose3d: self._count_convNd,
                 nn.BatchNorm1d: self._count_bn,
                 nn.BatchNorm2d: self._count_bn,
                 nn.BatchNorm3d: self._count_bn,
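The dictionary above is a type-keyed dispatch table: further down in this diff, count_module looks up self.ops[type(m)] and calls the matching counter hook. A minimal standalone sketch of that pattern (not NNI code; count_linear_sketch and the ops table here are hypothetical illustrations):

import torch
import torch.nn as nn

def count_linear_sketch(m, x, y):
    # rough per-sample FLOPs for a linear layer (bias ignored)
    return m.in_features * m.out_features

ops = {nn.Linear: count_linear_sketch}   # hypothetical table, analogous to self.ops

model = nn.Sequential(nn.Linear(8, 4), nn.ReLU())
x = torch.randn(1, 8)
for m in model.modules():
    counter = ops.get(type(m))           # modules without a registered counter are skipped
    if counter is not None:
        print(type(m).__name__, counter(m, (x,), None))   # -> Linear 32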
@@ -86,7 +93,7 @@ class ModelProfiler:
     def _count_convNd(self, m, x, y):
         cin = m.in_channels
-        kernel_ops = m.weight.size()[2] * m.weight.size()[3]
+        kernel_ops = torch.zeros(m.weight.size()[2:]).numel()
         output_size = torch.zeros(y.size()[2:]).numel()
         cout = y.size()[1]
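The old expression only handled 2D kernels; the new one takes the product of all trailing kernel dimensions, so Conv1d and Conv3d (and the transposed variants registered above) are counted correctly as well. A quick standalone check (a sketch, not part of the commit):

import torch
import torch.nn as nn

conv2d = nn.Conv2d(3, 8, kernel_size=(3, 5))
w = conv2d.weight
old_kernel_ops = w.size()[2] * w.size()[3]            # 3 * 5 = 15
new_kernel_ops = torch.zeros(w.size()[2:]).numel()    # also 15
assert old_kernel_ops == new_kernel_ops

conv1d = nn.Conv1d(3, 8, kernel_size=7)
print(torch.zeros(conv1d.weight.size()[2:]).numel())  # 7; size()[3] would raise IndexError here

conv3d = nn.Conv3d(3, 8, kernel_size=(2, 3, 4))
print(torch.zeros(conv3d.weight.size()[2:]).numel())  # 24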
@@ -156,13 +163,125 @@ class ModelProfiler:
         return self._get_result(m, total_ops)
 
+    def _count_cell_flops(self, input_size, hidden_size, cell_type):
+        # h' = \tanh(W_{ih} x + b_{ih} + W_{hh} h + b_{hh})
+        total_ops = hidden_size * (input_size + hidden_size) + hidden_size
+
+        if self._count_bias:
+            total_ops += hidden_size * 2
+
+        if cell_type == 'rnn':
+            return total_ops
+
+        if cell_type == 'gru':
+            # r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\
+            # z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\
+            # n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\
+            total_ops *= 3
+
+            # r hadamard : r * (~)
+            total_ops += hidden_size
+
+            # h' = (1 - z) * n + z * h
+            # hadamard hadamard add
+            total_ops += hidden_size * 3
+
+        elif cell_type == 'lstm':
+            # i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\
+            # f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\
+            # o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\
+            # g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\
+            total_ops *= 4
+
+            # c' = f * c + i * g
+            # hadamard hadamard add
+            total_ops += hidden_size * 3
+
+            # h' = o * \tanh(c')
+            total_ops += hidden_size
+
+        return total_ops
+
+    def _count_rnn_cell(self, m, x, y):
+        total_ops = self._count_cell_flops(m.input_size, m.hidden_size, 'rnn')
+        batch_size = x[0].size(0)
+        total_ops *= batch_size
+        return self._get_result(m, total_ops)
+
+    def _count_gru_cell(self, m, x, y):
+        total_ops = self._count_cell_flops(m.input_size, m.hidden_size, 'gru')
+        batch_size = x[0].size(0)
+        total_ops *= batch_size
+        return self._get_result(m, total_ops)
+
+    def _count_lstm_cell(self, m, x, y):
+        total_ops = self._count_cell_flops(m.input_size, m.hidden_size, 'lstm')
+        batch_size = x[0].size(0)
+        total_ops *= batch_size
+        return self._get_result(m, total_ops)
+
+    def _get_bsize_nsteps(self, m, x):
+        if isinstance(x[0], PackedSequence):
+            batch_size = torch.max(x[0].batch_sizes)
+            num_steps = x[0].batch_sizes.size(0)
+        else:
+            if m.batch_first:
+                batch_size = x[0].size(0)
+                num_steps = x[0].size(1)
+            else:
+                batch_size = x[0].size(1)
+                num_steps = x[0].size(0)
+        return batch_size, num_steps
+
+    def _count_rnn_module(self, m, x, y, module_name):
+        input_size = m.input_size
+        hidden_size = m.hidden_size
+        num_layers = m.num_layers
+
+        batch_size, num_steps = self._get_bsize_nsteps(m, x)
+        total_ops = self._count_cell_flops(input_size, hidden_size, module_name)
+
+        for _ in range(num_layers - 1):
+            if m.bidirectional:
+                cell_flops = self._count_cell_flops(hidden_size * 2, hidden_size, module_name) * 2
+            else:
+                cell_flops = self._count_cell_flops(hidden_size, hidden_size, module_name)
+            total_ops += cell_flops
+
+        total_ops *= num_steps
+        total_ops *= batch_size
+        return total_ops
+
+    def _count_rnn(self, m, x, y):
+        total_ops = self._count_rnn_module(m, x, y, 'rnn')
+        return self._get_result(m, total_ops)
+
+    def _count_gru(self, m, x, y):
+        total_ops = self._count_rnn_module(m, x, y, 'gru')
+        return self._get_result(m, total_ops)
+
+    def _count_lstm(self, m, x, y):
+        total_ops = self._count_rnn_module(m, x, y, 'lstm')
+        return self._get_result(m, total_ops)
+
     def count_module(self, m, x, y, name):
         # assume x is tuple of single tensor
         result = self.ops[type(m)](m, x, y)
+        output_size = y[0].size() if isinstance(y, tuple) else y.size()
+
         total_result = {
             'name': name,
             'input_size': tuple(x[0].size()),
-            'output_size': tuple(y.size()),
+            'output_size': tuple(output_size),
             'module_type': type(m).__name__,
             **result
         }
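To make the per-cell formula above concrete, here is a small standalone sketch that mirrors the _count_cell_flops arithmetic (with bias counting off, matching the default self._count_bias = False) and checks a worked example for input_size=10, hidden_size=20: base gate cost 20*(10+20)+20 = 620; LSTM uses four gates (2480), plus 3*20 for c' = f*c + i*g and 20 for h' = o*tanh(c'), giving 2560.

def cell_flops(input_size, hidden_size, cell_type, count_bias=False):
    # same arithmetic as _count_cell_flops in the hunk above
    total = hidden_size * (input_size + hidden_size) + hidden_size
    if count_bias:
        total += hidden_size * 2
    if cell_type == 'rnn':
        return total
    if cell_type == 'gru':
        total = total * 3 + hidden_size + hidden_size * 3
    elif cell_type == 'lstm':
        total = total * 4 + hidden_size * 3 + hidden_size
    return total

assert cell_flops(10, 20, 'rnn') == 620
assert cell_flops(10, 20, 'gru') == 1940   # 620*3 + 20 + 60
assert cell_flops(10, 20, 'lstm') == 2560  # 620*4 + 60 + 20

_count_rnn_module then sums this cell cost over layers (doubling the input width and the total for bidirectional layers) and multiplies by the number of time steps and the batch size.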
@@ -279,10 +398,6 @@ def count_flops_params(model, x, custom_ops=None, verbose=True, mode='default'):
     model(*x)
 
     # restore origin status
-    for name, m in model.named_modules():
-        if hasattr(m, 'weight_mask'):
-            delattr(m, 'weight_mask')
-
     model.train(training).to(original_device)
     for handler in handler_collection:
         handler.remove()
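The hunk above also shows the public entry point, count_flops_params(model, x, custom_ops=None, verbose=True, mode='default'). A minimal usage sketch that exercises the new RNN counters; TinyRNN is a hypothetical model, and the three-value return (flops, params, per-module results) is an assumption that should be checked against the installed NNI version:

import torch
import torch.nn as nn
from nni.compression.pytorch.utils.counter import count_flops_params

class TinyRNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=16, hidden_size=32, num_layers=2, batch_first=True)
        self.fc = nn.Linear(32, 4)

    def forward(self, x):
        out, _ = self.lstm(x)      # out: (batch, steps, hidden)
        return self.fc(out[:, -1])

dummy = torch.randn(1, 8, 16)      # (batch, steps, features); passed as a tuple since the counter calls model(*x)
flops, params, results = count_flops_params(TinyRNN(), (dummy,), mode='default')
print(flops, params)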