OpenDAS / nni · Commits

Commit be7eee0c (unverified)
Authored May 16, 2022 by Ningxin Zheng; committed by GitHub on May 16, 2022
enable customizing the replace function (#4826)
Parent: 98c1a77f
Showing 1 changed file with 25 additions and 5 deletions.
nni/compression/pytorch/speedup/compressor.py (+25, −5)
@@ -42,10 +42,23 @@ class ModelSpeedup:
         the index of batch dimension in the dummy_input
     confidence: the confidence coefficient of the sparsity inference. This value is
         actually used as the batchsize of the dummy_input.
+    customized_replace_func: None/Dict
+        If `customized_replace_func` is not None, we will use the given functions to replace the
+        corresponding modules. The `key` of the dict is the operator type and the `value` is the
+        replace function for that operator type. The replace function takes two input parameters:
+        the first is the original module, and the second is a tuple of the input mask, output
+        mask and weight mask. The replace function should prune the module accordingly. Here is
+        an example of a replace function (more examples can be found in compress_modules.py)::
+
+            def example_replace(ori_module, masks):
+                in_mask, out_mask, weight_mask = masks
+                # prune ori_module to a new, smaller module according to the masks
+                return new_small_module
+
     """

     def __init__(self, model, dummy_input, masks_file, map_location=None,
-                 batch_dim=0, confidence=8):
+                 batch_dim=0, confidence=8, customized_replace_func=None):
         assert confidence > 1
         # The auto inference will change the values of the parameters in the model
         # so we need to make a copy before the mask inference
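As a concrete illustration of this contract, here is a minimal sketch of a customized replace function for a `Linear` layer and how it might be passed in. It is not part of the commit: the exact mask shapes (a float mask with a batch dimension is assumed and averaged away), the `'Linear'` op-type key, and all names are assumptions for illustration:

    import torch
    import torch.nn as nn

    def my_replace_linear(ori_module, masks):
        # Unpack as in the docstring's example_replace; shapes are assumed.
        in_mask, out_mask, weight_mask = masks
        # Indices of the output features whose mask survives pruning.
        out_idx = torch.nonzero(out_mask.mean(dim=0).view(-1), as_tuple=True)[0]
        new_module = nn.Linear(ori_module.in_features, out_idx.numel(),
                               bias=ori_module.bias is not None)
        # Copy the surviving rows of the weight (and bias) into the smaller module.
        new_module.weight.data = ori_module.weight.data[out_idx]
        if ori_module.bias is not None:
            new_module.bias.data = ori_module.bias.data[out_idx]
        return new_module

    speedup = ModelSpeedup(model, dummy_input, masks_file,
                           customized_replace_func={'Linear': my_replace_linear})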
@@ -53,7 +66,8 @@ class ModelSpeedup:
         self.bound_model = model
         self.inferred_masks = dict()  # key: module_name, value: ModuleMasks
         self.batch_dim = batch_dim
-        self.dummy_input, self.device = self._random_model_input(dummy_input, confidence, batch_dim)
+        self.dummy_input, self.device = self._random_model_input(dummy_input, confidence, batch_dim)
         self.torch_graph = build_module_graph(model, self.dummy_input)
         # dict object to save the auto inferences objects of the submodules
         self.auto_inferences = {}
@@ -75,6 +89,7 @@ class ModelSpeedup:
         self.constant = {}
         # self.internal_result save the internal output of the submodules
         self.internal_result = {}
+        self.customized_replace_func = customized_replace_func if customized_replace_func is not None else {}

     def _random_model_input(self, dummy_input, confidence, batch_dim):
         """
@@ -284,7 +299,8 @@ class ModelSpeedup:
                 else:
                     last_output.grad = tin.grad
             else:
-                _logger.warning('Note: %s does not have corresponding mask inference object', node.name)
+                _logger.warning('Note: %s does not have corresponding mask inference object', node.name)

     def _vnode_to_value(self, c_node):
         """
@@ -408,6 +424,7 @@ class ModelSpeedup:
     method is shutdown, in the future, we will merge these two methods into a graph
     pass which is used to resolve the mask conflict.
     """
     def __init__(self, ori_module, reindex_dim, reindex):
         super(ReindexModule, self).__init__()
         self.ori_module = ori_module
@@ -441,12 +458,15 @@ class ModelSpeedup:
         super_module, leaf_module = get_module_by_name(self.bound_model, g_node.name)
         m_type = g_node.op_type
-        if not m_type in replace_module:
+        if (not m_type in replace_module) and (m_type not in self.customized_replace_func):
             raise RuntimeError("Has not supported replacing the module: `{}`".format(m_type))
         _logger.info("replace module (name: %s, op_type: %s)", g_node.name, m_type)
-        compressed_module = replace_module[m_type](leaf_module, auto_infer.get_masks())
+        replace_function = replace_module[m_type]
+        if m_type in self.customized_replace_func:
+            replace_function = self.customized_replace_func[m_type]
+        compressed_module = replace_function(leaf_module, auto_infer.get_masks())
         new_submodule = compressed_module
         if reindex_dim is None:
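The lookup above gives user-supplied functions precedence over the built-in `replace_module` table, so a customized entry can both override a supported type and extend coverage to an unsupported one. Note that, as written, `replace_function = replace_module[m_type]` executes before the customized lookup, so a type present only in `customized_replace_func` would raise a `KeyError` on that line; a `dict.get`-based variant avoids this. A minimal sketch of the intended precedence (op types and functions are hypothetical stand-ins):

    replace_module = {'Conv2d': lambda mod, masks: 'built-in Conv2d replacement'}
    customized_replace_func = {
        'Conv2d': lambda mod, masks: 'custom Conv2d replacement',  # overrides built-in
        'MyOp': lambda mod, masks: 'custom MyOp replacement',      # extends the table
    }

    for m_type in ('Conv2d', 'MyOp'):
        # .get avoids a KeyError when the type exists only in the customized dict
        replace_function = customized_replace_func.get(m_type, replace_module.get(m_type))
        print(m_type, '->', replace_function(None, None))
    # Conv2d -> custom Conv2d replacement
    # MyOp -> custom MyOp replacement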