Add custom op support of counter (#2795)

Commit d1c63562 (unverified) in OpenDAS / nni
Authored Sep 21, 2020 by colorjam; committed by GitHub on Sep 21, 2020
Parent: f892ed67
Showing 1 changed file with 14 additions and 11 deletions (+14 / -11):

src/sdk/pynni/nni/compression/torch/utils/counter.py
@@ -12,7 +12,7 @@ except Exception as e:
     raise

-def count_flops_params(model: nn.Module, input_size, verbose=True):
+def count_flops_params(model: nn.Module, input_size, custom_ops=None, verbose=True):
     """
     Count FLOPs and Params of the given model.
     This function would identify the mask on the module
@@ -28,7 +28,10 @@ def count_flops_params(model: nn.Module, input_size, verbose=True):
         target model.
     input_size: list, tuple
         the input shape of data
+    custom_ops: dict
+        a mapping of (module: custom operation)
+        the custom operation will overwrite the default operation.
+        for reference, please see ``custom_mask_ops``.

     Returns
     -------
@@ -44,11 +47,14 @@ def count_flops_params(model: nn.Module, input_size, verbose=True):
     inputs = torch.randn(input_size).to(device)

     hook_module_list = []
+    if custom_ops is None:
+        custom_ops = {}
+    custom_mask_ops.update(custom_ops)
     prev_m = None
     for m in model.modules():
         weight_mask = None
         m_type = type(m)
-        if m_type in custom_ops:
+        if m_type in custom_mask_ops:
             if isinstance(prev_m, PrunerModuleWrapper):
                 weight_mask = prev_m.weight_mask
...
@@ -56,7 +62,7 @@ def count_flops_params(model: nn.Module, input_size, verbose=True):
hook_module_list
.
append
(
m
)
prev_m
=
m
flops
,
params
=
profile
(
model
,
inputs
=
(
inputs
,
),
custom_ops
=
custom_ops
,
verbose
=
verbose
)
flops
,
params
=
profile
(
model
,
inputs
=
(
inputs
,
),
custom_ops
=
custom_
mask_
ops
,
verbose
=
verbose
)
for
m
in
hook_module_list
:
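The `profile` call here appears to come from thop (the file guards its import in the try/except shown at the top of the first hunk). Assuming that, a standalone sketch of thop's API, independent of this module:

```python
import torch
import torch.nn as nn
from thop import profile  # assumes thop is installed

model = nn.Linear(16, 32)
inputs = torch.randn(1, 16)
# custom_ops maps a module type to a counting hook, same shape as above.
flops, params = profile(model, inputs=(inputs, ), verbose=False)
print(flops, params)  # counts as reported by thop
```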
@@ -74,7 +80,6 @@ def count_flops_params(model: nn.Module, input_size, verbose=True):
 def count_convNd_mask(m, x, y):
     """
     The forward hook to count FLOPs and Parameters of convolution operation.
-
     Parameters
     ----------
     m : torch.nn.Module
@@ -101,7 +106,6 @@ def count_convNd_mask(m, x, y):
 def count_linear_mask(m, x, y):
     """
     The forward hook to count FLOPs and Parameters of linear transformation.
-
     Parameters
     ----------
     m : torch.nn.Module
@@ -111,22 +115,21 @@ def count_linear_mask(m, x, y):
     y : torch.Tensor
         output data
     """
-    output_channel = y.size()[1]
-    output_size = torch.zeros(y.size()[2:]).numel()
+    output_channel = y.numel()
     bias_flops = 1 if m.bias is not None else 0

     if m.weight_mask is not None:
         output_channel = m.weight_mask.sum() // m.in_features

-    total_ops = output_channel * output_size * (m.in_features + bias_flops)
+    total_ops = output_channel * (m.in_features + bias_flops)

     m.total_ops += torch.DoubleTensor([int(total_ops)])


-custom_ops = {
+custom_mask_ops = {
     nn.Conv1d: count_convNd_mask,
     nn.Conv2d: count_convNd_mask,
     nn.Conv3d: count_convNd_mask,
     nn.Linear: count_linear_mask,
 }
\ No newline at end of file
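To make the changed linear formula concrete, a short worked sketch of the masked branch (the values are illustrative):

```python
import torch
import torch.nn as nn

linear = nn.Linear(in_features=8, out_features=4, bias=True)

# Suppose a pruner zeroed out 2 of the 4 output neurons.
weight_mask = torch.ones_like(linear.weight)  # shape (4, 8)
weight_mask[1] = 0
weight_mask[3] = 0

bias_flops = 1 if linear.bias is not None else 0
# Unmasked weights / in_features recovers the surviving output channels.
output_channel = weight_mask.sum() // linear.in_features        # 16 // 8 = 2
total_ops = output_channel * (linear.in_features + bias_flops)  # 2 * (8 + 1) = 18
print(int(total_ops))  # 18
```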