Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
5c861676
Unverified
Commit
5c861676
authored
May 15, 2020
by
chicm-ms
Committed by
GitHub
May 15, 2020
Browse files
Graph torch14 refactor (#2384)
parent
ac238f01
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
1539 additions
and
551 deletions
+1539
-551
azure-pipelines.yml
azure-pipelines.yml
+46
-25
docs/en_US/Compressor/ModelSpeedup.md
docs/en_US/Compressor/ModelSpeedup.md
+1
-1
src/sdk/pynni/nni/_graph_utils.py
src/sdk/pynni/nni/_graph_utils.py
+490
-0
src/sdk/pynni/nni/compression/speedup/torch/compress_modules.py
...k/pynni/nni/compression/speedup/torch/compress_modules.py
+1
-0
src/sdk/pynni/nni/compression/speedup/torch/compressor.py
src/sdk/pynni/nni/compression/speedup/torch/compressor.py
+13
-389
src/sdk/pynni/nni/compression/speedup/torch/infer_shape.py
src/sdk/pynni/nni/compression/speedup/torch/infer_shape.py
+10
-0
src/sdk/pynni/nni/nas/pytorch/_graph_utils.py
src/sdk/pynni/nni/nas/pytorch/_graph_utils.py
+0
-134
src/sdk/pynni/nni/nas/pytorch/mutator.py
src/sdk/pynni/nni/nas/pytorch/mutator.py
+2
-2
src/sdk/pynni/tests/expect/test_graph_module1.expect
src/sdk/pynni/tests/expect/test_graph_module1.expect
+152
-0
src/sdk/pynni/tests/expect/test_graph_module2.expect
src/sdk/pynni/tests/expect/test_graph_module2.expect
+309
-0
src/sdk/pynni/tests/expect/test_graph_module3.expect
src/sdk/pynni/tests/expect/test_graph_module3.expect
+250
-0
src/sdk/pynni/tests/test_graph_utils.py
src/sdk/pynni/tests/test_graph_utils.py
+158
-0
src/sdk/pynni/tests/test_model_speedup.py
src/sdk/pynni/tests/test_model_speedup.py
+107
-0
No files found.
azure-pipelines.yml
View file @
5c861676
# Azure hosted agents specification:
# https://docs.microsoft.com/en-us/azure/devops/pipelines/agents/hosted?view=azure-devops
jobs
:
jobs
:
-
job
:
'
basic_test_pr_ubuntu
'
-
job
:
'
ubuntu_1804_python36
'
pool
:
pool
:
vmImage
:
'
Ubuntu
16.04'
vmImage
:
'
Ubuntu
18.04'
strategy
:
matrix
:
Python36
:
PYTHON_VERSION
:
'
3.6'
steps
:
steps
:
-
script
:
|
-
script
:
|
...
@@ -26,9 +25,8 @@ jobs:
...
@@ -26,9 +25,8 @@ jobs:
yarn eslint
yarn eslint
displayName
:
'
Run
eslint'
displayName
:
'
Run
eslint'
-
script
:
|
-
script
:
|
python3 -m pip install torch==1.2.0 --user
python3 -m pip install torch==1.5.0+cpu torchvision==0.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html --user
python3 -m pip install torchvision==0.4.0 --user
python3 -m pip install tensorflow==1.15.2 --user
python3 -m pip install tensorflow==1.13.1 --user
python3 -m pip install keras==2.1.6 --user
python3 -m pip install keras==2.1.6 --user
python3 -m pip install gym onnx --user
python3 -m pip install gym onnx --user
python3 -m pip install sphinx==1.8.3 sphinx-argparse==0.2.5 sphinx-markdown-tables==0.0.9 sphinx-rtd-theme==0.4.2 sphinxcontrib-websupport==1.1.0 recommonmark==0.5.0 --user
python3 -m pip install sphinx==1.8.3 sphinx-argparse==0.2.5 sphinx-markdown-tables==0.0.9 sphinx-rtd-theme==0.4.2 sphinxcontrib-websupport==1.1.0 recommonmark==0.5.0 --user
...
@@ -63,13 +61,41 @@ jobs:
...
@@ -63,13 +61,41 @@ jobs:
sphinx-build -M html . _build -W
sphinx-build -M html . _build -W
displayName
:
'
Sphinx
Documentation
Build
check'
displayName
:
'
Sphinx
Documentation
Build
check'
-
job
:
'
basic_test_pr_macOS'
-
job
:
'
ubuntu_1604_python35_legacy_torch'
pool
:
vmImage
:
'
Ubuntu
16.04'
steps
:
-
script
:
|
python3 -m pip install --upgrade pip setuptools --user
python3 -m pip install coverage --user
echo "##vso[task.setvariable variable=PATH]${HOME}/.local/bin:${PATH}"
displayName
:
'
Install
python
tools'
-
script
:
|
source install.sh
displayName
:
'
Install
nni
toolkit
via
source
code'
-
script
:
|
python3 -m pip install torch==1.3.1+cpu torchvision==0.4.2+cpu -f https://download.pytorch.org/whl/torch_stable.html --user
python3 -m pip install tensorflow==1.15.2 --user
python3 -m pip install keras==2.1.6 --user
python3 -m pip install gym onnx --user
sudo apt-get install swig -y
nnictl package install --name=SMAC
nnictl package install --name=BOHB
displayName
:
'
Install
dependencies'
-
script
:
|
cd test
source scripts/unittest.sh
displayName
:
'
Unit
test'
-
script
:
|
cd test
python3 nni_test/nnitest/run_tests.py --config config/pr_tests.yml
displayName
:
'
Simple
test'
-
job
:
'
macos_1015_python37'
pool
:
pool
:
vmImage
:
'
macOS-10.15'
vmImage
:
'
macOS-10.15'
strategy
:
matrix
:
Python36
:
PYTHON_VERSION
:
'
3.6'
steps
:
steps
:
-
script
:
python3 -m pip install --upgrade pip setuptools
-
script
:
python3 -m pip install --upgrade pip setuptools
...
@@ -79,9 +105,9 @@ jobs:
...
@@ -79,9 +105,9 @@ jobs:
echo "##vso[task.setvariable variable=PATH]${HOME}/Library/Python/3.7/bin:${PATH}"
echo "##vso[task.setvariable variable=PATH]${HOME}/Library/Python/3.7/bin:${PATH}"
displayName
:
'
Install
nni
toolkit
via
source
code'
displayName
:
'
Install
nni
toolkit
via
source
code'
-
script
:
|
-
script
:
|
pyt
hon3 -m pip install torch==1.2.0 --user
#
pyt
orch Mac binary does not support CUDA, default is cpu version
python3 -m pip install torchvision==0.
4
.0 --user
python3 -m pip install torchvision==0.
6.0 torch==1.5
.0 --user
python3 -m pip install tensorflow==1.1
3.1
--user
python3 -m pip install tensorflow==1.1
5.2
--user
brew install swig@3
brew install swig@3
rm /usr/local/bin/swig
rm /usr/local/bin/swig
ln -s /usr/local/opt/swig\@3/bin/swig /usr/local/bin/swig
ln -s /usr/local/opt/swig\@3/bin/swig /usr/local/bin/swig
...
@@ -96,13 +122,9 @@ jobs:
...
@@ -96,13 +122,9 @@ jobs:
python3 nni_test/nnitest/run_tests.py --config config/pr_tests.yml
python3 nni_test/nnitest/run_tests.py --config config/pr_tests.yml
displayName
:
'
Simple
test'
displayName
:
'
Simple
test'
-
job
:
'
basic_test_pr_Windows
'
-
job
:
'
win2016_python37
'
pool
:
pool
:
vmImage
:
'
vs2017-win2016'
vmImage
:
'
vs2017-win2016'
strategy
:
matrix
:
Python36
:
PYTHON_VERSION
:
'
3.6'
steps
:
steps
:
-
script
:
|
-
script
:
|
...
@@ -111,9 +133,8 @@ jobs:
...
@@ -111,9 +133,8 @@ jobs:
-
script
:
|
-
script
:
|
python -m pip install scikit-learn==0.20.0 --user
python -m pip install scikit-learn==0.20.0 --user
python -m pip install keras==2.1.6 --user
python -m pip install keras==2.1.6 --user
python -m pip install torch===1.2.0 torchvision===0.4.1 -f https://download.pytorch.org/whl/torch_stable.html --user
python -m pip install torch==1.5.0+cpu torchvision==0.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html --user
python -m pip install torchvision --user
python -m pip install tensorflow==1.15.2 --user
python -m pip install tensorflow==1.13.1 --user
displayName
:
'
Install
dependencies'
displayName
:
'
Install
dependencies'
-
script
:
|
-
script
:
|
cd test
cd test
...
...
docs/en_US/Compressor/ModelSpeedup.md
View file @
5c861676
...
@@ -34,7 +34,7 @@ print('elapsed time: ', time.time() - start)
...
@@ -34,7 +34,7 @@ print('elapsed time: ', time.time() - start)
```
```
For complete examples please refer to
[
the code
](
https://github.com/microsoft/nni/tree/master/examples/model_compress/model_speedup.py
)
For complete examples please refer to
[
the code
](
https://github.com/microsoft/nni/tree/master/examples/model_compress/model_speedup.py
)
NOTE: The current implementation
only w
or
k
s
on t
orch 1.3.1
and torchvision 0.4.2
NOTE: The current implementation
supp
or
t
s
PyT
orch 1.3.1
or newer.
## Limitations
## Limitations
...
...
src/sdk/pynni/nni/_graph_utils.py
0 → 100644
View file @
5c861676
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
logging
import
queue
import
re
from
collections
import
defaultdict
import
torch
from
torch.utils.tensorboard._pytorch_graph
import
NodePy
,
NodePyIO
,
NodePyOP
,
GraphPy
CLASSTYPE_KIND
=
'ClassType'
GETATTR_KIND
=
'prim::GetAttr'
_logger
=
logging
.
getLogger
(
__name__
)
def
build_module_graph
(
model
,
dummy_input
):
return
TorchModuleGraph
(
model
,
dummy_input
)
def
build_graph
(
model
,
dummy_input
,
verbose
=
False
):
g
=
TorchProtoGraph
(
model
,
dummy_input
,
verbose
)
return
g
.
graph_def
,
g
.
stepstats
def
parse_traced_name
(
module_name
):
prefix
=
'TracedModule['
suffix
=
']'
if
module_name
.
startswith
(
prefix
)
and
module_name
.
endswith
(
suffix
):
module_name
=
module_name
[
len
(
prefix
):
-
len
(
suffix
)]
return
module_name
class
TorchGraph
:
"""
This class is to extract pytorch model topology graph by tracing
"""
def
__init__
(
self
,
model
,
dummy_input
):
"""
Parameters
----------
model : pytorch model
The model user wants to speed up
dummy_input : pytorch tensor
The dummy input for ```jit.trace```, users should put it on right device before pass in
"""
assert
torch
.
__version__
>=
'1.3.1'
self
.
bound_model
=
model
self
.
_trace
(
model
,
dummy_input
)
def
_trace
(
self
,
model
,
dummy_input
):
with
torch
.
onnx
.
set_training
(
model
,
False
):
self
.
trace
=
torch
.
jit
.
trace
(
model
,
dummy_input
)
torch
.
_C
.
_jit_pass_inline
(
self
.
trace
.
graph
)
class
TorchProtoGraph
(
TorchGraph
):
"""
Generates model graph for pytorch models in protobuf, this implementation is borrowed from pytorch v1.4.0,
and fixed following issues:
https://github.com/pytorch/pytorch/issues/33691
https://github.com/pytorch/pytorch/issues/33670
"""
def
__init__
(
self
,
model
,
dummy_input
,
verbose
=
False
):
super
().
__init__
(
model
,
dummy_input
)
from
tensorboard.compat.proto.config_pb2
import
RunMetadata
from
tensorboard.compat.proto.graph_pb2
import
GraphDef
from
tensorboard.compat.proto.step_stats_pb2
import
StepStats
,
DeviceStepStats
from
tensorboard.compat.proto.versions_pb2
import
VersionDef
list_of_nodes
=
self
.
parse
(
self
.
trace
.
graph
,
self
.
trace
,
dummy_input
)
if
verbose
:
print
(
self
.
trace
.
graph
)
self
.
stepstats
=
RunMetadata
(
step_stats
=
StepStats
(
dev_stats
=
[
DeviceStepStats
(
device
=
"/device:CPU:0"
)]))
self
.
graph_def
=
GraphDef
(
node
=
list_of_nodes
,
versions
=
VersionDef
(
producer
=
22
))
def
parse
(
self
,
graph
,
trace
,
args
=
None
,
omit_useless_nodes
=
True
):
"""This method parses an optimized PyTorch model graph and produces
a list of nodes and node stats for eventual conversion to TensorBoard
protobuf format.
Args:
graph (PyTorch module): The model graph to be parsed.
trace (PyTorch JIT TracedModule): The model trace to be parsed.
args (tuple): input tensor[s] for the model.
omit_useless_nodes (boolean): Whether to remove nodes from the graph.
"""
nodes_py
=
GraphPy
()
for
node
in
graph
.
inputs
():
if
omit_useless_nodes
:
if
not
node
.
uses
():
# number of user of the node (= number of outputs/ fanout)
continue
if
node
.
type
().
kind
()
!=
CLASSTYPE_KIND
:
nodes_py
.
append
(
NodePyIO
(
node
,
'input'
))
attr_to_scope
=
dict
()
node_to_name
=
lambda
d
:
str
(
d
).
split
(
":"
)[
0
].
strip
()
for
node
in
graph
.
nodes
():
if
node
.
kind
()
==
GETATTR_KIND
:
attr_name
=
node
.
s
(
'name'
)
node_name
=
node_to_name
(
node
)
parent
=
node
.
input
().
node
()
if
parent
.
kind
()
==
GETATTR_KIND
:
# If the parent node is not the top-level "self" node
parent_scope
=
attr_to_scope
[
node_to_name
(
parent
)]
attr_scope
=
parent_scope
.
split
(
'/'
)[
-
1
]
attr_to_scope
[
node_name
]
=
'{}/{}.{}'
.
format
(
parent_scope
,
attr_scope
,
attr_name
)
else
:
attr_to_scope
[
node_name
]
=
'__module.{}'
.
format
(
attr_name
)
# We don't need classtype nodes; scope will provide this information
if
node
.
output
().
type
().
kind
()
!=
CLASSTYPE_KIND
:
node_py
=
NodePyOP
(
node
)
node_py
.
scopeName
=
attr_to_scope
[
node_name
]
nodes_py
.
append
(
node_py
)
else
:
nodes_py
.
append
(
NodePyOP
(
node
))
for
i
,
node
in
enumerate
(
graph
.
outputs
()):
# Create sink nodes for output ops
node_py
=
NodePyIO
(
node
,
'output'
)
node_py
.
debugName
=
"output.{}"
.
format
(
i
+
1
)
node_py
.
inputs
=
[
node
.
debugName
()]
nodes_py
.
append
(
node_py
)
alias_to_name
=
dict
()
base_name
=
parse_traced_name
(
trace
.
_name
)
for
name
,
module
in
trace
.
named_modules
(
prefix
=
'__module'
):
mod_name
=
parse_traced_name
(
module
.
_name
)
attr_name
=
name
.
split
(
'.'
)[
-
1
]
alias_to_name
[
name
]
=
'{}[{}]'
.
format
(
mod_name
,
attr_name
)
for
node
in
nodes_py
.
nodes_op
:
module_aliases
=
node
.
scopeName
.
split
(
'/'
)[
-
1
].
split
(
'.'
)
module_name
=
''
for
i
,
alias
in
enumerate
(
module_aliases
):
if
i
==
0
:
module_name
=
alias
node
.
scopeName
=
base_name
else
:
module_name
+=
'.'
+
alias
node
.
scopeName
+=
'/'
+
(
alias_to_name
[
module_name
]
if
module_name
in
alias_to_name
else
alias
)
nodes_py
.
populate_namespace_from_OP_to_IO
()
return
nodes_py
.
to_proto
()
class
NodePyGroup
(
NodePy
):
"""
This class is used to represent a graph node which consists of multiple jit traced nodes. In a pytorch trace graph,
there are multiple nodes are traced for one torch.nn.Module object, we group them together to form a single node to
represent the torch.nn.Module object. We also group some functional call trace nodes together to form a new node.
"""
def
__init__
(
self
,
name
,
node_type
,
op_type
,
node_cpps
,
inputs
=
None
,
outputs
=
None
):
"""
Parameters:
-----------
name: str
node name, such as `conv1`, `backbone.classifier`
node_type: str
`module` or `func`
op_type: str
operation type, such as `Conv2d`, `aten::view`
node_cpps: list of torch._C.Node
jit trace nodes which are included in this new node
inputs: list of str
All the inputs of this node, each element is debugName of one input
outputs: list of str
All the outputs of this node, each element is debugName of one output
"""
super
(
NodePyGroup
,
self
).
__init__
(
name
,
[])
self
.
node_cpps
=
node_cpps
self
.
name
=
name
self
.
op_type
=
op_type
self
.
type
=
node_type
self
.
nodes
=
[]
self
.
auxiliary
=
None
self
.
add_nodes
(
node_cpps
)
self
.
inputs
=
inputs
self
.
outputs
=
outputs
def
add_nodes
(
self
,
node_cpps
):
for
node_cpp
in
node_cpps
:
nodepy
=
NodePyOP
(
node_cpp
)
nodepy
.
name
=
str
(
node_cpp
).
split
(
':'
)[
0
].
strip
().
replace
(
'%'
,
''
)
self
.
nodes
.
append
(
nodepy
)
def
sub_node_names
(
self
):
return
[
x
.
name
for
x
in
self
.
nodes
]
def
__repr__
(
self
):
return
'name: {}, type: {}, op_type: {}, sub_nodes: {}, inputs: {}, outputs: {}, aux: {}'
.
format
(
self
.
name
,
self
.
type
,
self
.
op_type
,
self
.
sub_node_names
(),
self
.
inputs
,
self
.
outputs
,
self
.
auxiliary
)
class
TorchModuleGraph
(
TorchGraph
):
"""
Generates model graph, each node is created from single or multiple jit trace nodes.
"""
def
__init__
(
self
,
model
,
dummy_input
):
super
().
__init__
(
model
,
dummy_input
)
self
.
global_count
=
0
self
.
name_to_node
,
self
.
input_to_node
,
self
.
output_to_node
=
self
.
_build_graph
()
def
_expand_non_prim_node
(
self
,
node
,
nodes
,
input_to_node
,
output_to_node
):
"""
For trace graph nodes, some nodes are not in modules, these nodes are usually generated by
the functions directly called in module ```forward```. For such nodes, some of them are
trivial op which are label by ```prim::```, some of them are not such ops which is call
non-prim ops. This function is to merge neighbor prim ops to a non-prim op, to construct
a node.
Parameters
----------
node : trace graph node
The non-prim node to expand
nodes : list of trace graph node
All the trace graph nodes within the same scope as the non-prim node
input_to_node : dict
key: input name, value: a node that uses this input
output_to_node : dict
key: output name, value: a node that generates this output
Returns
-------
node
the expanded non-prim node
"""
# TODO: scope name could be empty
node_name
=
'.'
.
join
([
self
.
_get_module_name
(
node
.
scopeName
()),
node
.
kind
(),
str
(
self
.
global_count
)])
_logger
.
debug
(
"expand non-prim node, node name: %s"
,
node_name
)
self
.
global_count
+=
1
op_type
=
node
.
kind
()
node_group
=
[
node
]
inputs
=
list
()
outputs
=
list
()
node_queue
=
queue
.
Queue
()
node_queue
.
put
(
node
)
while
not
node_queue
.
empty
():
curr_node
=
node_queue
.
get
()
for
_input
in
curr_node
.
inputs
():
input_name
=
_input
.
debugName
()
if
input_name
in
output_to_node
and
output_to_node
[
input_name
]
in
nodes
:
predecessor_node
=
output_to_node
[
input_name
]
if
predecessor_node
.
kind
().
startswith
(
'prim::'
):
node_group
.
append
(
predecessor_node
)
node_queue
.
put
(
predecessor_node
)
else
:
inputs
.
append
(
input_name
)
else
:
inputs
.
append
(
input_name
)
for
output
in
node
.
outputs
():
outputs
.
append
(
output
.
debugName
())
nodepy
=
NodePyGroup
(
node_name
,
'func'
,
op_type
,
node_group
,
inputs
=
inputs
,
outputs
=
outputs
)
return
nodepy
def
_build_module_node_group
(
self
,
module_name
,
op_type
,
node_cpps
,
input_to_node
,
output_to_node
):
graph
=
self
.
trace
.
graph
inputs
,
outputs
=
[],
[]
for
n
in
node_cpps
:
for
i
in
n
.
inputs
():
name
=
i
.
debugName
()
if
not
name
in
output_to_node
and
i
in
graph
.
inputs
():
inputs
.
append
(
name
)
elif
output_to_node
[
name
]
not
in
node_cpps
:
inputs
.
append
(
name
)
for
o
in
n
.
outputs
():
name
=
o
.
debugName
()
if
not
name
in
input_to_node
and
o
in
graph
.
outputs
():
outputs
.
append
(
name
)
elif
input_to_node
[
name
]
not
in
node_cpps
:
outputs
.
append
(
name
)
return
NodePyGroup
(
module_name
,
'module'
,
op_type
,
node_cpps
,
inputs
,
outputs
)
def
_extract_shape_info
(
self
,
node
):
"""
Extract the shape information of ```aten::view``` node
Parameters
----------
node : trace graph node
It should be ```aten::view``` node
Returns
-------
dict
Include shape of input tensor and shape of output tensor
"""
t_input
=
None
for
_input
in
node
.
inputs
():
t_input
=
_input
break
t_output
=
node
.
output
()
assert
isinstance
(
t_input
.
type
(),
torch
.
_C
.
TensorType
)
assert
isinstance
(
t_output
.
type
(),
torch
.
_C
.
TensorType
)
in_shape
=
t_input
.
type
().
sizes
()
out_shape
=
t_output
.
type
().
sizes
()
return
{
'in_shape'
:
in_shape
,
'out_shape'
:
out_shape
}
def
_extract_leaf_modules
(
self
):
"""
Extract leaf modules from the given graph. Leaf module means it does not have submodules.
To extract leaf modules because only leaf module can be replaced. And shape inference can
be done in leaf module level. Other shape inference is done in lower level i.e.,
operation level.
Returns
-------
list
a list of scope name of all the leaf modules
"""
def
is_parent
(
name1
,
name2
):
"""
check if name1 is parent node of name2, for example:
name1: aa.bb, name2: aa.bb.cc, return True
name1: aa.b, name2: aa.bb, return False
"""
parts1
,
parts2
=
name1
.
split
(
'.'
),
name2
.
split
(
'.'
)
if
len
(
parts1
)
>=
len
(
parts2
):
return
False
for
i
in
range
(
len
(
parts1
)):
if
parts2
[
i
]
!=
parts1
[
i
]:
return
False
return
True
module_names
=
sorted
([
x
[
0
]
for
x
in
self
.
trace
.
named_modules
()
if
x
[
0
]])
leaf_nodes
=
[]
for
i
,
name
in
enumerate
(
module_names
):
if
i
+
1
>=
len
(
module_names
)
or
not
is_parent
(
name
,
module_names
[
i
+
1
]):
leaf_nodes
.
append
(
name
)
return
leaf_nodes
def
_get_module_name
(
self
,
scope_name
):
"""
Retrieve module name from scope name.
Parameters:
-----------
scope_name: str
scope_name of a graph node, for example:
for pytorch 1.3.1: MyModel/BackboneModel[backbone]/Conv2d[conv2]
for pytorch 1.4.0: __module.backbone/__module.backbone.conv2
Returns:
-------
str
module name, such as backbone.conv2
"""
if
torch
.
__version__
>=
'1.4.0'
:
return
scope_name
.
split
(
'/'
)[
-
1
].
replace
(
'__module.'
,
''
)
else
:
return
'.'
.
join
(
re
.
findall
(
r
'\[(.*?)\]'
,
scope_name
))
def
_build_index
(
self
,
nodes_op
):
name_to_node
=
dict
()
input_to_node
=
defaultdict
(
list
)
output_to_node
=
dict
()
for
node
in
nodes_op
:
name_to_node
[
node
.
name
]
=
node
for
_input
in
node
.
inputs
:
input_to_node
[
_input
].
append
(
node
)
for
output
in
node
.
outputs
:
assert
not
output
in
output_to_node
,
\
"One output cannot be generated by multiple nodes"
output_to_node
[
output
]
=
node
return
name_to_node
,
input_to_node
,
output_to_node
def
_build_graph
(
self
):
"""
Build graph using our defined format from jit trace.
There are basically three steps: first, construct necessary information (data structures),
second, extract all the modules to convert to node, Third, extract all functions to convert
to node.
Returns
-------
dict
use name to index nodes, key: node name, value: node
dict
use input (its name) to index nodes,
key: input, value: list of nodes that take this input
dict
use output (its name) to index nodes,
key: output, value: node that generates this output
"""
omit_useless_nodes
=
True
graph
=
self
.
trace
.
graph
_logger
.
debug
(
graph
)
# build output mapping, from output debugName to its node
output_to_node
=
{
x
.
debugName
():
n
for
n
in
graph
.
nodes
()
for
x
in
n
.
outputs
()}
# build input mapping, from input debugName to its node
input_to_node
=
{
x
.
debugName
():
n
for
n
in
graph
.
nodes
()
for
x
in
n
.
inputs
()}
# build module mapping, from module name to all nodes (as list) under this module scope
module_to_nodes
=
defaultdict
(
list
)
# the mapping of function (non-module in forward) to nodes, key is scope name
func_to_nodes
=
defaultdict
(
list
)
nodes_py
=
GraphPy
()
for
node
in
graph
.
inputs
():
if
omit_useless_nodes
:
if
not
node
.
uses
():
# number of user of the node (= number of outputs/ fanout)
continue
if
node
.
type
().
kind
()
!=
'ClassType'
:
nodes_py
.
append
(
NodePyIO
(
node
,
'input'
))
self
.
leaf_modules
=
self
.
_extract_leaf_modules
()
module_to_type
=
{
name
:
parse_traced_name
(
module
.
_name
)
for
name
,
module
in
self
.
trace
.
named_modules
()}
# associate module name with their trace graph nodes
for
node
in
graph
.
nodes
():
module_name
=
self
.
_get_module_name
(
node
.
scopeName
())
if
module_name
in
self
.
leaf_modules
:
module_to_nodes
[
module_name
].
append
(
node
)
else
:
func_to_nodes
[
node
.
scopeName
()].
append
(
node
)
# build node group for module
for
module_name
,
node_cpps
in
module_to_nodes
.
items
():
node_group
=
self
.
_build_module_node_group
(
module_name
,
module_to_type
[
module_name
],
node_cpps
,
input_to_node
,
output_to_node
)
_logger
.
debug
(
'node_group: %s'
,
node_group
)
nodes_py
.
nodes_op
.
append
(
node_group
)
# each scope_name may have multiple funcs, we split them and create node for each of them
# build node group for torch.nn.functional
for
_
,
nodes
in
func_to_nodes
.
items
():
# extract non prim:: nodes
non_prim_nodes
=
list
()
for
node
in
nodes
:
if
not
node
.
kind
().
startswith
(
'prim::'
):
non_prim_nodes
.
append
(
node
)
# for each non prim node, expand it
for
node
in
non_prim_nodes
:
node_group
=
self
.
_expand_non_prim_node
(
node
,
nodes
,
input_to_node
,
output_to_node
)
nodes_py
.
nodes_op
.
append
(
node_group
)
# get shape infor for view (aten::view) func
if
node_group
.
op_type
in
[
'aten::view'
,
'aten::flatten'
]:
node_group
.
auxiliary
=
self
.
_extract_shape_info
(
node
)
for
node
in
graph
.
outputs
():
# Create sink nodes for output ops
node_py
=
NodePyIO
(
node
,
'output'
)
nodes_py
.
append
(
node_py
)
self
.
nodes_py
=
nodes_py
# build index
return
self
.
_build_index
(
self
.
nodes_py
.
nodes_op
)
def
find_predecessors
(
self
,
module_name
):
"""
Find predecessor node of the given node
Parameters
----------
module_name : str
The name of the node
Returns
-------
list
a list of nodes who are the given node's predecessor
"""
predecessors
=
[]
for
_input
in
self
.
name_to_node
[
module_name
].
inputs
:
if
not
_input
in
self
.
output_to_node
:
_logger
.
debug
(
"cannot find node with %s as its output"
,
_input
)
else
:
node_py
=
self
.
output_to_node
[
_input
]
predecessors
.
append
(
node_py
.
name
)
return
predecessors
def
find_successors
(
self
,
module_name
):
"""
Find successor nodes of the given node
Parameters
----------
module_name : str
The name of the node
Returns
-------
list
a list of nodes who are the given node's successor
"""
successors
=
[]
for
output
in
self
.
name_to_node
[
module_name
].
outputs
:
assert
output
in
self
.
input_to_node
,
"No node with input {}"
.
format
(
output
)
nodes_py
=
self
.
input_to_node
[
output
]
for
node_py
in
nodes_py
:
successors
.
append
(
node_py
.
name
)
return
successors
src/sdk/pynni/nni/compression/speedup/torch/compress_modules.py
View file @
5c861676
...
@@ -12,6 +12,7 @@ replace_module = {
...
@@ -12,6 +12,7 @@ replace_module = {
'Conv2d'
:
lambda
module
,
mask
:
replace_conv2d
(
module
,
mask
),
'Conv2d'
:
lambda
module
,
mask
:
replace_conv2d
(
module
,
mask
),
'MaxPool2d'
:
lambda
module
,
mask
:
no_replace
(
module
,
mask
),
'MaxPool2d'
:
lambda
module
,
mask
:
no_replace
(
module
,
mask
),
'AvgPool2d'
:
lambda
module
,
mask
:
no_replace
(
module
,
mask
),
'AvgPool2d'
:
lambda
module
,
mask
:
no_replace
(
module
,
mask
),
'AdaptiveAvgPool2d'
:
lambda
module
,
mask
:
no_replace
(
module
,
mask
),
'ReLU'
:
lambda
module
,
mask
:
no_replace
(
module
,
mask
),
'ReLU'
:
lambda
module
,
mask
:
no_replace
(
module
,
mask
),
'Linear'
:
lambda
module
,
mask
:
replace_linear
(
module
,
mask
)
'Linear'
:
lambda
module
,
mask
:
replace_linear
(
module
,
mask
)
}
}
...
...
src/sdk/pynni/nni/compression/speedup/torch/compressor.py
View file @
5c861676
...
@@ -2,9 +2,8 @@
...
@@ -2,9 +2,8 @@
# Licensed under the MIT license.
# Licensed under the MIT license.
import
logging
import
logging
import
queue
import
re
import
torch
import
torch
from
nni._graph_utils
import
build_module_graph
from
.compress_modules
import
replace_module
from
.compress_modules
import
replace_module
from
.infer_shape
import
ModuleMasks
,
infer_from_mask
,
infer_from_inshape
,
infer_from_outshape
from
.infer_shape
import
ModuleMasks
,
infer_from_mask
,
infer_from_inshape
,
infer_from_outshape
...
@@ -33,38 +32,6 @@ def get_module_by_name(model, module_name):
...
@@ -33,38 +32,6 @@ def get_module_by_name(model, module_name):
leaf_module
=
getattr
(
model
,
name_list
[
-
1
])
leaf_module
=
getattr
(
model
,
name_list
[
-
1
])
return
model
,
leaf_module
return
model
,
leaf_module
class
GNode
:
"""
It is used to represent a node in model graph, in this graph a module is a node,
a function out of module (in ```forward``` function) could also be a node.
"""
def
__init__
(
self
,
node_name
,
node_type
,
op_type
,
inputs
,
outputs
,
nodes
):
"""
Parameters
----------
node_name : str
It is module name if the node is a module, it is ```scope_name.node_kind.seq``` if it is a func
node_type : str
It only has two options: `module` or `func`
op_type : str
The operation type of the module or func
inputs : list of str
All the inputs of this node, each element is debugName of one input
outputs : list of str
All the outputs of this node, each element is debugName of one output
nodes : list of node
All the trace graph nodes included in this module or func
"""
self
.
name
=
node_name
self
.
type
=
node_type
self
.
op_type
=
op_type
self
.
inputs
=
inputs
self
.
outputs
=
outputs
self
.
nodes
=
nodes
# store supplementary information for different op types
# for example, for ```view``` it stores the shape of its input and output
self
.
auxiliary
=
None
class
ModelSpeedup
:
class
ModelSpeedup
:
"""
"""
This class is to speedup the model with provided weight mask
This class is to speedup the model with provided weight mask
...
@@ -84,347 +51,9 @@ class ModelSpeedup:
...
@@ -84,347 +51,9 @@ class ModelSpeedup:
the device on which masks are placed, same to map_location in ```torch.load```
the device on which masks are placed, same to map_location in ```torch.load```
"""
"""
self
.
bound_model
=
model
self
.
bound_model
=
model
self
.
dummy_input
=
dummy_input
self
.
masks
=
torch
.
load
(
masks_file
,
map_location
)
self
.
masks
=
torch
.
load
(
masks_file
,
map_location
)
self
.
is_training
=
model
.
training
# to obtain forward graph, model should be in ```eval``` mode
if
self
.
is_training
:
model
.
eval
()
self
.
trace_graph
=
torch
.
jit
.
trace
(
model
,
dummy_input
)
if
self
.
is_training
:
model
.
train
()
self
.
inferred_masks
=
dict
()
# key: module_name, value: ModuleMasks
self
.
inferred_masks
=
dict
()
# key: module_name, value: ModuleMasks
self
.
g_nodes
=
list
()
self
.
torch_graph
=
build_module_graph
(
model
,
dummy_input
)
self
.
global_count
=
0
self
.
name_to_gnode
,
self
.
input_to_gnode
,
self
.
output_to_gnode
=
self
.
_build_graph
()
def
_build_index_for_gnodes
(
self
,
g_nodes
):
"""
Build indexes for quick search
Parameters
----------
g_nodes : list of GNode
All the g_node in processed model graph
Returns
-------
dict
use name to index g_nodes, key: node name, value: g_node
dict
use input (its name) to index g_nodes,
key: input, value: list of g_nodes that take this input
dict
use output (its name) to index g_nodes,
key: output, value: g_node that generates this output
"""
name_to_gnode
=
dict
()
input_to_gnode
=
dict
()
output_to_gnode
=
dict
()
for
node
in
g_nodes
:
name_to_gnode
[
node
.
name
]
=
node
for
_input
in
node
.
inputs
:
if
_input
in
input_to_gnode
:
input_to_gnode
[
_input
].
append
(
node
)
else
:
input_to_gnode
[
_input
]
=
[
node
]
for
output
in
node
.
outputs
:
assert
not
output
in
output_to_gnode
,
\
"One output cannot be generated by multiple nodes"
output_to_gnode
[
output
]
=
node
return
name_to_gnode
,
input_to_gnode
,
output_to_gnode
def
_expand_non_prim_node
(
self
,
node
,
nodes
,
input_to_node
,
output_to_node
):
"""
For trace graph nodes, some nodes are not in modules, these nodes are usually generated by
the functions directly called in module ```forward```. For such nodes, some of them are
trivial op which are label by ```prim::```, some of them are not such ops which is call
non-prim ops. This function is to merge neighbor prim ops to a non-prim op, to construct
a GNode.
Parameters
----------
node : trace graph node
The non-prim node to expand
nodes : list of trace graph node
All the trace graph nodes within the same scope as the non-prim node
input_to_node : dict
key: input name, value: a node that uses this input
output_to_node : dict
key: output name, value: a node that generates this output
Returns
-------
GNode
the expanded non-prim node in GNode format
"""
# TODO: scope name could be empty
node_name
=
'.'
.
join
([
node
.
scopeName
(),
node
.
kind
(),
str
(
self
.
global_count
)])
_logger
.
debug
(
"expand non-prim node, node name: %s"
,
node_name
)
self
.
global_count
+=
1
op_type
=
node
.
kind
()
node_group
=
[
node
]
inputs
=
list
()
outputs
=
list
()
node_queue
=
queue
.
Queue
()
node_queue
.
put
(
node
)
while
not
node_queue
.
empty
():
curr_node
=
node_queue
.
get
()
for
_input
in
curr_node
.
inputs
():
input_name
=
_input
.
debugName
()
if
input_name
in
output_to_node
and
output_to_node
[
input_name
]
in
nodes
:
predecessor_node
=
output_to_node
[
input_name
]
if
predecessor_node
.
kind
().
startswith
(
'prim::'
):
node_group
.
append
(
predecessor_node
)
node_queue
.
put
(
predecessor_node
)
else
:
inputs
.
append
(
input_name
)
else
:
inputs
.
append
(
input_name
)
for
output
in
node
.
outputs
():
outputs
.
append
(
output
.
debugName
())
g_node
=
GNode
(
node_name
,
'func'
,
op_type
,
inputs
,
outputs
,
node_group
)
return
g_node
def
_extract_shape_info
(
self
,
node
):
"""
Extract the shape information of ```aten::view``` node
Parameters
----------
node : trace graph node
It should be ```aten::view``` node
Returns
-------
dict
Include shape of input tensor and shape of output tensor
"""
t_input
=
None
for
_input
in
node
.
inputs
():
t_input
=
_input
break
t_output
=
node
.
output
()
assert
isinstance
(
t_input
.
type
(),
torch
.
_C
.
TensorType
)
assert
isinstance
(
t_output
.
type
(),
torch
.
_C
.
TensorType
)
in_shape
=
t_input
.
type
().
sizes
()
out_shape
=
t_output
.
type
().
sizes
()
return
{
'in_shape'
:
in_shape
,
'out_shape'
:
out_shape
}
def
_extract_leaf_modules
(
self
,
graph
):
"""
Extract leaf modules from the given graph. Leaf module means it does not have submodules.
To extract leaf modules because only leaf module can be replaced. And shape inference can
be done in leaf module level. Other shape inference is done in lower level i.e.,
operation level.
Parameters
----------
graph : jit trace graph
the graph generated from jit trace
Returns
-------
list
a list of scope name of all the leaf modules
"""
class
SNode
:
def
__init__
(
self
,
name
):
self
.
sname
=
name
self
.
childs
=
{}
root
=
None
for
node
in
graph
.
nodes
():
scope_name
=
node
.
scopeName
()
if
scope_name
==
''
:
continue
segs
=
scope_name
.
split
(
'/'
)
if
root
is
None
:
root
=
SNode
(
segs
[
0
])
curr
=
root
for
seg
in
segs
[
1
:]:
if
not
seg
in
curr
.
childs
:
curr
.
childs
[
seg
]
=
SNode
(
seg
)
curr
=
curr
.
childs
[
seg
]
leaf_nodes
=
[]
def
traverse_tree
(
node
,
scope_name
):
if
scope_name
==
''
:
sn
=
node
.
sname
else
:
sn
=
scope_name
+
'/'
+
node
.
sname
if
not
node
.
childs
:
if
node
.
sname
[
-
1
]
==
']'
:
leaf_nodes
.
append
(
sn
)
else
:
for
key
in
node
.
childs
:
traverse_tree
(
node
.
childs
[
key
],
sn
)
traverse_tree
(
root
,
''
)
return
leaf_nodes
def _build_graph(self):
    """
    Build a graph in our own format (GNode) from the jit trace graph.

    Three steps: first, construct the necessary lookup tables; second,
    convert all leaf modules to GNodes; third, convert all non-module
    functions in ``forward`` to GNodes.

    Returns
    -------
    dict
        use name to index g_nodes, key: node name, value: g_node
    dict
        use input (its name) to index g_nodes,
        key: input, value: list of g_nodes that take this input
    dict
        use output (its name) to index g_nodes,
        key: output, value: g_node that generates this output
    """
    graph = self.trace_graph.graph
    # if torch 1.4.0 is used, consider run torch._C._jit_pass_inline(graph) here
    _logger.debug(graph)
    # build output mapping, from output debugName to its node
    output_to_node = dict()
    # build input mapping, from input debugName to its node
    input_to_node = dict()
    # build module mapping, from module name to all nodes (as list) under this module scope
    module_to_nodes = dict()
    # module name to its type
    module_to_type = dict()
    # the mapping of function (non-module in forward) to nodes, key is scope name
    func_to_nodes = dict()

    # debugNames of the graph-level inputs/outputs, used below to decide
    # whether a tensor crosses the module boundary
    graph_inputs = list()
    graph_outputs = list()
    for _input in graph.inputs():
        graph_inputs.append(_input.debugName())
    for output in graph.outputs():
        graph_outputs.append(output.debugName())

    leaf_modules = self._extract_leaf_modules(graph)
    _logger.debug(leaf_modules)

    # Bucket every jit node either under its leaf module or under its
    # function scope name.
    for node in graph.nodes():
        # populate output_to_node and input_to_node
        for output in node.outputs():
            output_name = output.debugName()
            output_to_node[output_name] = node
        for _input in node.inputs():
            input_name = _input.debugName()
            input_to_node[input_name] = node
        scope_name = node.scopeName()  # example: scope_name, 'MyCell/Linear[linear]'
        # if module_name is empty, it is not a module
        if not scope_name in leaf_modules:
            if scope_name == '':
                continue
            else:
                if scope_name in func_to_nodes:
                    func_to_nodes[scope_name].append(node)
                else:
                    func_to_nodes[scope_name] = [node]
        else:
            # 'MyCell/Linear[linear]' -> module name 'linear', type 'Linear'
            module_name_slices = re.findall(r'\[(.*?)\]', scope_name)
            module_name = '.'.join(module_name_slices)
            scope_slice = scope_name.split('/')[-1]
            module_type = scope_slice.split('[')[0]
            module_to_type[module_name] = module_type
            if module_name in module_to_nodes:
                module_to_nodes[module_name].append(node)
            else:
                module_to_nodes[module_name] = [node]

    # construct GNode from module: the module's external inputs/outputs are
    # the tensors produced/consumed outside the module's own node set
    for module_name, nodes in module_to_nodes.items():
        inputs = set()
        outputs = set()
        for node in nodes:
            for output in node.outputs():
                outputs.add(output.debugName())
            for _input in node.inputs():
                inputs.add(_input.debugName())
        m_inputs = list()
        m_outputs = list()
        for output in outputs:
            # TODO: one input could be the input of multiple nodes
            # NOTE(review): if `output` has no consumer and is not a graph
            # output, the elif raises KeyError — confirm that case is
            # unreachable for traced graphs
            if not output in input_to_node and output in graph_outputs:
                m_outputs.append(output)
            elif not input_to_node[output] in nodes:
                m_outputs.append(output)
        for _input in inputs:
            # NOTE(review): symmetric KeyError risk when `_input` has no
            # producer and is not a graph input (e.g. module parameters)
            if not _input in output_to_node and _input in graph_inputs:
                m_inputs.append(_input)
            elif not output_to_node[_input] in nodes:
                m_inputs.append(_input)
        if module_name == '':
            _logger.warning("module_name is empty string")
        g_node = GNode(module_name, 'module', module_to_type[module_name], m_inputs, m_outputs, nodes)
        self.g_nodes.append(g_node)

    # each scope_name may have multiple funcs, we split them and create GNode for each of them
    for scope_name, nodes in func_to_nodes.items():
        # extract non prim:: nodes
        non_prim_nodes = list()
        for node in nodes:
            if not node.kind().startswith('prim::'):
                non_prim_nodes.append(node)
        # for each non prim node, expand it as a GNode
        for node in non_prim_nodes:
            g_node = self._expand_non_prim_node(node, nodes, input_to_node, output_to_node)
            self.g_nodes.append(g_node)
            # get shape info for view (aten::view) func
            if g_node.op_type == 'aten::view':
                g_node.auxiliary = self._extract_shape_info(node)

    # build index for g_nodes
    name_to_gnode, input_to_gnode, output_to_gnode = self._build_index_for_gnodes(self.g_nodes)

    return name_to_gnode, input_to_gnode, output_to_gnode
def _find_predecessors(self, module_name):
    """
    Find predecessor GNodes of the given GNode.

    Parameters
    ----------
    module_name : str
        The name of the GNode

    Returns
    -------
    list
        names of the GNodes that produce this GNode's inputs
    """
    predecessors = []
    for tensor_name in self.name_to_gnode[module_name].inputs:
        producer = self.output_to_gnode.get(tensor_name)
        if producer is None:
            # no GNode produces this tensor (e.g. a graph input or constant)
            _logger.debug("cannot find gnode with %s as its output", tensor_name)
            continue
        predecessors.append(producer.name)
    return predecessors
def _find_successors(self, module_name):
    """
    Find successor GNodes of the given GNode.

    Parameters
    ----------
    module_name : str
        The name of the GNode

    Returns
    -------
    list
        names of the GNodes that consume this GNode's outputs
    """
    successors = []
    for out_name in self.name_to_gnode[module_name].outputs:
        # every module output is expected to be consumed by some GNode
        assert out_name in self.input_to_gnode, "No gnode with input {}".format(out_name)
        successors.extend(consumer.name for consumer in self.input_to_gnode[out_name])
    return successors
def
infer_module_mask
(
self
,
module_name
,
mask
=
None
,
in_shape
=
None
,
out_shape
=
None
):
def
infer_module_mask
(
self
,
module_name
,
mask
=
None
,
in_shape
=
None
,
out_shape
=
None
):
"""
"""
...
@@ -441,13 +70,13 @@ class ModelSpeedup:
...
@@ -441,13 +70,13 @@ class ModelSpeedup:
Parameters
Parameters
----------
----------
module_name : str
module_name : str
The name of the
GN
ode
The name of the
n
ode
mask : tensor of mask or ModuleMasks
mask : tensor of mask or ModuleMasks
Mask of the weights in this
GN
ode (i.e., module)
Mask of the weights in this
n
ode (i.e., module)
in_shape : ModuleMasks
in_shape : ModuleMasks
Input shape of this
GN
ode
Input shape of this
n
ode
out_shape : ModuleMasks
out_shape : ModuleMasks
Output shape of this
GN
ode
Output shape of this
n
ode
"""
"""
input_cmask
=
output_cmask
=
None
input_cmask
=
output_cmask
=
None
if
module_name
in
self
.
inferred_masks
:
if
module_name
in
self
.
inferred_masks
:
...
@@ -456,7 +85,7 @@ class ModelSpeedup:
...
@@ -456,7 +85,7 @@ class ModelSpeedup:
module_masks
=
ModuleMasks
(
module_name
)
module_masks
=
ModuleMasks
(
module_name
)
self
.
inferred_masks
[
module_name
]
=
module_masks
self
.
inferred_masks
[
module_name
]
=
module_masks
m_type
=
self
.
name_to_
g
node
[
module_name
].
op_type
m_type
=
self
.
torch_graph
.
name_to_node
[
module_name
].
op_type
_logger
.
debug
(
"infer mask of module %s with op_type %s"
,
module_name
,
m_type
)
_logger
.
debug
(
"infer mask of module %s with op_type %s"
,
module_name
,
m_type
)
if
mask
is
not
None
:
if
mask
is
not
None
:
_logger
.
debug
(
"mask is not None"
)
_logger
.
debug
(
"mask is not None"
)
...
@@ -471,10 +100,10 @@ class ModelSpeedup:
...
@@ -471,10 +100,10 @@ class ModelSpeedup:
raise
RuntimeError
(
raise
RuntimeError
(
"Has not supported infering output shape from input shape for module/function: `{}`, {}"
"Has not supported infering output shape from input shape for module/function: `{}`, {}"
.
format
(
m_type
,
module_name
))
.
format
(
m_type
,
module_name
))
if
m_type
==
'aten::view'
:
if
m_type
in
[
'aten::view'
,
'aten::flatten'
]
:
output_cmask
=
infer_from_inshape
[
m_type
](
module_masks
,
output_cmask
=
infer_from_inshape
[
m_type
](
module_masks
,
in_shape
,
in_shape
,
self
.
name_to_
g
node
[
module_name
].
auxiliary
)
self
.
torch_graph
.
name_to_node
[
module_name
].
auxiliary
)
else
:
else
:
output_cmask
=
infer_from_inshape
[
m_type
](
module_masks
,
in_shape
)
output_cmask
=
infer_from_inshape
[
m_type
](
module_masks
,
in_shape
)
if
out_shape
is
not
None
:
if
out_shape
is
not
None
:
...
@@ -486,11 +115,11 @@ class ModelSpeedup:
...
@@ -486,11 +115,11 @@ class ModelSpeedup:
input_cmask
=
infer_from_outshape
[
m_type
](
module_masks
,
out_shape
)
input_cmask
=
infer_from_outshape
[
m_type
](
module_masks
,
out_shape
)
if
input_cmask
:
if
input_cmask
:
predecessors
=
self
.
_
find_predecessors
(
module_name
)
predecessors
=
self
.
torch_graph
.
find_predecessors
(
module_name
)
for
_module_name
in
predecessors
:
for
_module_name
in
predecessors
:
self
.
infer_module_mask
(
_module_name
,
out_shape
=
input_cmask
)
self
.
infer_module_mask
(
_module_name
,
out_shape
=
input_cmask
)
if
output_cmask
:
if
output_cmask
:
successors
=
self
.
_
find_successors
(
module_name
)
successors
=
self
.
torch_graph
.
find_successors
(
module_name
)
for
_module_name
in
successors
:
for
_module_name
in
successors
:
self
.
infer_module_mask
(
_module_name
,
in_shape
=
output_cmask
)
self
.
infer_module_mask
(
_module_name
,
in_shape
=
output_cmask
)
...
@@ -511,7 +140,7 @@ class ModelSpeedup:
...
@@ -511,7 +140,7 @@ class ModelSpeedup:
is that ```func``` should be not required to be replaced.
is that ```func``` should be not required to be replaced.
"""
"""
for
module_name
in
self
.
inferred_masks
:
for
module_name
in
self
.
inferred_masks
:
g_node
=
self
.
name_to_
g
node
[
module_name
]
g_node
=
self
.
torch_graph
.
name_to_node
[
module_name
]
_logger
.
debug
(
"replace %s, in %s type, with op_type %s"
,
_logger
.
debug
(
"replace %s, in %s type, with op_type %s"
,
module_name
,
g_node
.
type
,
g_node
.
op_type
)
module_name
,
g_node
.
type
,
g_node
.
op_type
)
if
g_node
.
type
==
'module'
:
if
g_node
.
type
==
'module'
:
...
@@ -526,7 +155,7 @@ class ModelSpeedup:
...
@@ -526,7 +155,7 @@ class ModelSpeedup:
_logger
.
info
(
"Warning: cannot replace (name: %s, op_type: %s) which is func type"
,
_logger
.
info
(
"Warning: cannot replace (name: %s, op_type: %s) which is func type"
,
module_name
,
g_node
.
op_type
)
module_name
,
g_node
.
op_type
)
else
:
else
:
raise
RuntimeError
(
"Unsupported
GN
ode type: {}"
.
format
(
g_node
.
type
))
raise
RuntimeError
(
"Unsupported
n
ode type: {}"
.
format
(
g_node
.
type
))
def
speedup_model
(
self
):
def
speedup_model
(
self
):
"""
"""
...
@@ -540,8 +169,3 @@ class ModelSpeedup:
...
@@ -540,8 +169,3 @@ class ModelSpeedup:
_logger
.
info
(
"replace compressed modules..."
)
_logger
.
info
(
"replace compressed modules..."
)
self
.
replace_compressed_modules
()
self
.
replace_compressed_modules
()
_logger
.
info
(
"speedup done"
)
_logger
.
info
(
"speedup done"
)
# resume the model mode to that before the model is speed up
if
self
.
is_training
:
self
.
bound_model
.
train
()
else
:
self
.
bound_model
.
eval
()
\ No newline at end of file
src/sdk/pynni/nni/compression/speedup/torch/infer_shape.py
View file @
5c861676
...
@@ -83,6 +83,9 @@ class CoarseMask:
...
@@ -83,6 +83,9 @@ class CoarseMask:
cmask
.
mask_index
[
i
])
cmask
.
mask_index
[
i
])
return
self
.
mask_index
return
self
.
mask_index
def
__repr__
(
self
):
return
'mask_index: {}'
.
format
(
self
.
mask_index
)
class
ModuleMasks
:
class
ModuleMasks
:
"""
"""
The masks of a module, including the masks for weights, inputs, output
The masks of a module, including the masks for weights, inputs, output
...
@@ -128,6 +131,11 @@ class ModuleMasks:
...
@@ -128,6 +131,11 @@ class ModuleMasks:
"""
"""
self
.
output_mask
=
mask
self
.
output_mask
=
mask
def
__repr__
(
self
):
return
'input_mask: {}, output_mask: {}, param_masks: {}'
.
format
(
self
.
input_mask
,
self
.
output_mask
,
self
.
param_masks
)
"""
"""
Infer input and output shape of a module/function from its weight mask
Infer input and output shape of a module/function from its weight mask
"""
"""
...
@@ -147,8 +155,10 @@ infer_from_inshape = {
...
@@ -147,8 +155,10 @@ infer_from_inshape = {
'aten::max_pool2d'
:
lambda
module_masks
,
mask
:
maxpool2d_inshape
(
module_masks
,
mask
),
'aten::max_pool2d'
:
lambda
module_masks
,
mask
:
maxpool2d_inshape
(
module_masks
,
mask
),
'aten::avg_pool2d'
:
lambda
module_masks
,
mask
:
maxpool2d_inshape
(
module_masks
,
mask
),
'aten::avg_pool2d'
:
lambda
module_masks
,
mask
:
maxpool2d_inshape
(
module_masks
,
mask
),
'AvgPool2d'
:
lambda
module_masks
,
mask
:
maxpool2d_inshape
(
module_masks
,
mask
),
'AvgPool2d'
:
lambda
module_masks
,
mask
:
maxpool2d_inshape
(
module_masks
,
mask
),
'AdaptiveAvgPool2d'
:
lambda
module_masks
,
mask
:
maxpool2d_inshape
(
module_masks
,
mask
),
'aten::size'
:
lambda
module_masks
,
mask
:
size_inshape
(
module_masks
,
mask
),
'aten::size'
:
lambda
module_masks
,
mask
:
size_inshape
(
module_masks
,
mask
),
'aten::view'
:
lambda
module_masks
,
mask
,
shape
:
view_inshape
(
module_masks
,
mask
,
shape
),
'aten::view'
:
lambda
module_masks
,
mask
,
shape
:
view_inshape
(
module_masks
,
mask
,
shape
),
'aten::flatten'
:
lambda
module_masks
,
mask
,
shape
:
view_inshape
(
module_masks
,
mask
,
shape
),
# support only start_dim=1
'Linear'
:
lambda
module_masks
,
mask
:
linear_inshape
(
module_masks
,
mask
),
'Linear'
:
lambda
module_masks
,
mask
:
linear_inshape
(
module_masks
,
mask
),
'BatchNorm2d'
:
lambda
module_masks
,
mask
:
batchnorm2d_inshape
(
module_masks
,
mask
)
'BatchNorm2d'
:
lambda
module_masks
,
mask
:
batchnorm2d_inshape
(
module_masks
,
mask
)
}
}
...
...
src/sdk/pynni/nni/nas/pytorch/_graph_utils.py
deleted
100644 → 0
View file @
ac238f01
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: skip-file
# This file is copied from PyTorch 1.4, with bug fixes.
# Likely to be removed in future.
import
torch
from
tensorboard.compat.proto.config_pb2
import
RunMetadata
from
tensorboard.compat.proto.graph_pb2
import
GraphDef
from
tensorboard.compat.proto.step_stats_pb2
import
StepStats
,
DeviceStepStats
from
tensorboard.compat.proto.versions_pb2
import
VersionDef
from
torch.utils.tensorboard._pytorch_graph
import
GraphPy
,
CLASSTYPE_KIND
,
GETATTR_KIND
,
NodePyIO
,
NodePyOP
def parse(graph, trace, args=None, omit_useless_nodes=True):
    """This method parses an optimized PyTorch model graph and produces
    a list of nodes and node stats for eventual conversion to TensorBoard
    protobuf format.

    Args:
      graph (PyTorch module): The model graph to be parsed.
      trace (PyTorch JIT TracedModule): The model trace to be parsed.
      args (tuple): input tensor[s] for the model.
      omit_useless_nodes (boolean): Whether to remove nodes from the graph.
    """
    # NOTE(review): n_inputs and scope are never used below — kept as-is
    # from the upstream PyTorch copy of this function.
    n_inputs = len(args)

    scope = {}
    nodes_py = GraphPy()
    # Create IO nodes for the graph inputs, skipping the implicit class-type
    # 'self' input and (optionally) unused inputs.
    for node in graph.inputs():
        if omit_useless_nodes:
            if len(node.uses()) == 0:  # number of user of the node (= number of outputs/ fanout)
                continue

        if node.type().kind() != CLASSTYPE_KIND:
            nodes_py.append(NodePyIO(node, 'input'))

    attr_to_scope = dict()

    # debugName-like key for a node: the text before the first ':' in repr
    node_to_name = lambda d: str(d).split(":")[0].strip()
    for node in graph.nodes():
        if node.kind() == GETATTR_KIND:
            # Reconstruct module scope names ('__module.a.b') by chaining
            # prim::GetAttr nodes back to the top-level 'self'.
            attr_name = node.s('name')
            node_name = node_to_name(node)
            parent = node.input().node()
            if parent.kind() == GETATTR_KIND:  # If the parent node is not the top-level "self" node
                parent_attr_name = parent.s('name')
                parent_scope = attr_to_scope[node_to_name(parent)]
                attr_scope = parent_scope.split('/')[-1]
                attr_to_scope[node_name] = '{}/{}.{}'.format(parent_scope, attr_scope, attr_name)
            else:
                attr_to_scope[node_name] = '__module.{}'.format(attr_name)
            # We don't need classtype nodes; scope will provide this information
            if node.output().type().kind() != CLASSTYPE_KIND:
                node_py = NodePyOP(node)
                node_py.scopeName = attr_to_scope[node_name]
                nodes_py.append(node_py)
        else:
            nodes_py.append(NodePyOP(node))

    for i, node in enumerate(graph.outputs()):  # Create sink nodes for output ops
        node_py = NodePyIO(node, 'output')
        node_py.debugName = "output.{}".format(i + 1)
        node_py.inputs = [node.debugName()]
        nodes_py.append(node_py)

    def parse_traced_name(module_name):
        # Strip the 'TracedModule[...]' wrapper to get the real class name.
        prefix = 'TracedModule['
        suffix = ']'
        if module_name.startswith(prefix) and module_name.endswith(suffix):
            module_name = module_name[len(prefix):-len(suffix)]
        return module_name

    # Map attribute paths ('__module.a.b') to display names ('Type[attr]').
    alias_to_name = dict()
    base_name = parse_traced_name(trace._name)
    for name, module in trace.named_modules(prefix='__module'):
        mod_name = parse_traced_name(module._name)
        attr_name = name.split('.')[-1]
        alias_to_name[name] = '{}[{}]'.format(mod_name, attr_name)

    # Rewrite each op node's scopeName from attribute aliases into the
    # human-readable 'Base/Type[attr]/...' form.
    for node in nodes_py.nodes_op:
        module_aliases = node.scopeName.split('/')[-1].split('.')
        module_name = ''
        for i, alias in enumerate(module_aliases):
            if i == 0:
                module_name = alias
                node.scopeName = base_name
            else:
                module_name += '.' + alias
                node.scopeName += '/' + (alias_to_name[module_name] if module_name in alias_to_name else alias)

    nodes_py.populate_namespace_from_OP_to_IO()
    return nodes_py.to_proto()
def graph(model, args, verbose=False):
    """
    This method processes a PyTorch model and produces a `GraphDef` proto
    that can be logged to TensorBoard.

    Args:
      model (PyTorch module): The model to be parsed.
      args (tuple): input tensor[s] for the model.
      verbose (bool): Whether to print out verbose information while
        processing.
    """
    with torch.onnx.set_training(model, False):  # TODO: move outside of torch.onnx?
        try:
            trace = torch.jit.trace(model, args)
            jit_graph = trace.graph
            torch._C._jit_pass_inline(jit_graph)
        except RuntimeError as e:
            print(e)
            print('Error occurs, No graph saved')
            raise e

    if verbose:
        print(jit_graph)

    list_of_nodes = parse(jit_graph, trace, args)
    # We are hardcoding that this was run on CPU even though it might have actually
    # run on GPU. Note this is what is shown in TensorBoard and has no bearing
    # on actual execution.
    # TODO: See if we can extract GPU vs CPU information from the PyTorch model
    # and pass it correctly to TensorBoard.
    #
    # Definition of StepStats and DeviceStepStats can be found at
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/graph/tf_graph_common/test/graph-test.ts
    # and
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/step_stats.proto
    stepstats = RunMetadata(
        step_stats=StepStats(dev_stats=[DeviceStepStats(device="/device:CPU:0")]))
    # The producer version has been reverse engineered from standard
    # TensorBoard logged data.
    return GraphDef(node=list_of_nodes, versions=VersionDef(producer=22)), stepstats
src/sdk/pynni/nni/nas/pytorch/mutator.py
View file @
5c861676
...
@@ -107,12 +107,12 @@ class Mutator(BaseMutator):
...
@@ -107,12 +107,12 @@ class Mutator(BaseMutator):
"""
"""
if
not
torch
.
__version__
.
startswith
(
"1.4"
):
if
not
torch
.
__version__
.
startswith
(
"1.4"
):
logger
.
warning
(
"Graph is only tested with PyTorch 1.4. Other versions might not work."
)
logger
.
warning
(
"Graph is only tested with PyTorch 1.4. Other versions might not work."
)
from
._graph_utils
import
graph
from
nni
._graph_utils
import
build_
graph
from
google.protobuf
import
json_format
from
google.protobuf
import
json_format
# protobuf should be installed as long as tensorboard is installed
# protobuf should be installed as long as tensorboard is installed
try
:
try
:
self
.
_connect_all
=
True
self
.
_connect_all
=
True
graph_def
,
_
=
graph
(
self
.
model
,
inputs
,
verbose
=
False
)
graph_def
,
_
=
build_
graph
(
self
.
model
,
inputs
,
verbose
=
False
)
result
=
json_format
.
MessageToDict
(
graph_def
)
result
=
json_format
.
MessageToDict
(
graph_def
)
finally
:
finally
:
self
.
_connect_all
=
False
self
.
_connect_all
=
False
...
...
src/sdk/pynni/tests/expect/test_graph_module1.expect
0 → 100644
View file @
5c861676
node {
name: "input/input"
op: "IO Node"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 1
}
dim {
size: 3
}
}
}
}
}
attr {
key: "attr"
value {
s: ""
}
}
}
node {
name: "output/output.1"
op: "IO Node"
input: "myLinear/Linear[l]/22"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 1
}
dim {
size: 5
}
}
}
}
}
attr {
key: "attr"
value {
s: ""
}
}
}
node {
name: "myLinear/Linear[l]/bias/17"
op: "prim::GetAttr"
input: "myLinear/Linear[l]/weight/14"
attr {
key: "attr"
value {
s: "{ name : bias }"
}
}
}
node {
name: "myLinear/Linear[l]/weight/18"
op: "prim::GetAttr"
input: "myLinear/Linear[l]/weight/14"
attr {
key: "attr"
value {
s: "{ name : weight }"
}
}
}
node {
name: "myLinear/Linear[l]/19"
op: "aten::t"
input: "myLinear/Linear[l]/weight/18"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 3
}
dim {
size: 5
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
node {
name: "myLinear/Linear[l]/20"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "myLinear/Linear[l]/21"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "myLinear/Linear[l]/22"
op: "aten::addmm"
input: "myLinear/Linear[l]/bias/17"
input: "input/input"
input: "myLinear/Linear[l]/19"
input: "myLinear/Linear[l]/20"
input: "myLinear/Linear[l]/21"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 1
}
dim {
size: 5
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
versions {
producer: 22
}
src/sdk/pynni/tests/expect/test_graph_module2.expect
0 → 100644
View file @
5c861676
node {
name: "input/input.1"
op: "IO Node"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 4
}
dim {
size: 5
}
}
}
}
}
attr {
key: "attr"
value {
s: ""
}
}
}
node {
name: "output/output.1"
op: "IO Node"
input: "input/input.1"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 4
}
dim {
size: 5
}
}
}
}
}
attr {
key: "attr"
value {
s: ""
}
}
}
node {
name: "MyModule/Linear[weight]/bias/49"
op: "prim::GetAttr"
input: "MyModule/Linear[weight]/weight/35"
attr {
key: "attr"
value {
s: "{ name : bias }"
}
}
}
node {
name: "MyModule/Linear[weight]/weight/50"
op: "prim::GetAttr"
input: "MyModule/Linear[weight]/weight/35"
attr {
key: "attr"
value {
s: "{ name : weight }"
}
}
}
node {
name: "MyModule/Linear[weight]/51"
op: "aten::t"
input: "MyModule/Linear[weight]/weight/50"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 5
}
dim {
size: 3
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
node {
name: "MyModule/Linear[weight]/52"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "MyModule/Linear[weight]/53"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "MyModule/Linear[weight]/54"
op: "aten::addmm"
input: "MyModule/Linear[weight]/bias/49"
input: "input/input.1"
input: "MyModule/Linear[weight]/51"
input: "MyModule/Linear[weight]/52"
input: "MyModule/Linear[weight]/53"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 4
}
dim {
size: 3
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
node {
name: "MyModule/Linear[bias]/bias/55"
op: "prim::GetAttr"
input: "MyModule/Linear[bias]/weight/38"
attr {
key: "attr"
value {
s: "{ name : bias }"
}
}
}
node {
name: "MyModule/Linear[bias]/weight/56"
op: "prim::GetAttr"
input: "MyModule/Linear[bias]/weight/38"
attr {
key: "attr"
value {
s: "{ name : weight }"
}
}
}
node {
name: "MyModule/Linear[bias]/57"
op: "aten::t"
input: "MyModule/Linear[bias]/weight/56"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 5
}
dim {
size: 3
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
node {
name: "MyModule/Linear[bias]/58"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "MyModule/Linear[bias]/59"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "MyModule/Linear[bias]/60"
op: "aten::addmm"
input: "MyModule/Linear[bias]/bias/55"
input: "input/input.1"
input: "MyModule/Linear[bias]/57"
input: "MyModule/Linear[bias]/58"
input: "MyModule/Linear[bias]/59"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 4
}
dim {
size: 3
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
node {
name: "MyModule/23"
op: "prim::ListConstruct"
input: "MyModule/Linear[weight]/54"
input: "MyModule/Linear[bias]/60"
attr {
key: "attr"
value {
s: "{}"
}
}
}
node {
name: "MyModule/24"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "MyModule/input"
op: "aten::cat"
input: "MyModule/23"
input: "MyModule/24"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 4
}
dim {
size: 6
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
node {
name: "MyModule/61"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{}"
}
}
}
versions {
producer: 22
}
src/sdk/pynni/tests/expect/test_graph_module3.expect
0 → 100644
View file @
5c861676
node {
name: "input/input.1"
op: "IO Node"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 4
}
dim {
size: 5
}
}
}
}
}
attr {
key: "attr"
value {
s: ""
}
}
}
node {
name: "output/output.1"
op: "IO Node"
input: "MyModule/ModuleList[module]/Linear[1]/46"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 4
}
dim {
size: 1
}
}
}
}
}
attr {
key: "attr"
value {
s: ""
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[0]/bias/35"
op: "prim::GetAttr"
input: "MyModule/ModuleList[module]/Linear[0]/weight/26"
attr {
key: "attr"
value {
s: "{ name : bias }"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[0]/weight/36"
op: "prim::GetAttr"
input: "MyModule/ModuleList[module]/Linear[0]/weight/26"
attr {
key: "attr"
value {
s: "{ name : weight }"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[0]/37"
op: "aten::t"
input: "MyModule/ModuleList[module]/Linear[0]/weight/36"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 5
}
dim {
size: 3
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[0]/38"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[0]/39"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[0]/input"
op: "aten::addmm"
input: "MyModule/ModuleList[module]/Linear[0]/bias/35"
input: "input/input.1"
input: "MyModule/ModuleList[module]/Linear[0]/37"
input: "MyModule/ModuleList[module]/Linear[0]/38"
input: "MyModule/ModuleList[module]/Linear[0]/39"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 4
}
dim {
size: 3
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[1]/bias/41"
op: "prim::GetAttr"
input: "MyModule/ModuleList[module]/Linear[1]/weight/30"
attr {
key: "attr"
value {
s: "{ name : bias }"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[1]/weight/42"
op: "prim::GetAttr"
input: "MyModule/ModuleList[module]/Linear[1]/weight/30"
attr {
key: "attr"
value {
s: "{ name : weight }"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[1]/43"
op: "aten::t"
input: "MyModule/ModuleList[module]/Linear[1]/weight/42"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 3
}
dim {
size: 1
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[1]/44"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[1]/45"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "MyModule/ModuleList[module]/Linear[1]/46"
op: "aten::addmm"
input: "MyModule/ModuleList[module]/Linear[1]/bias/41"
input: "MyModule/ModuleList[module]/Linear[0]/input"
input: "MyModule/ModuleList[module]/Linear[1]/43"
input: "MyModule/ModuleList[module]/Linear[1]/44"
input: "MyModule/ModuleList[module]/Linear[1]/45"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 4
}
dim {
size: 1
}
}
}
}
}
attr {
key: "attr"
value {
s: "{}"
}
}
}
versions {
producer: 22
}
src/sdk/pynni/tests/test_graph_utils.py
0 → 100644
View file @
5c861676
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
sys
import
os
import
math
import
uuid
import
shutil
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
tensorboard.compat.proto.graph_pb2
import
GraphDef
from
google.protobuf
import
text_format
import
unittest
from
unittest
import
TestCase
,
main
from
nni._graph_utils
import
build_module_graph
,
build_graph
class BackboneModel1(nn.Module):
    """Minimal backbone: one 1x1 single-channel convolution."""

    def __init__(self):
        super().__init__()
        # in_channels=1, out_channels=1, kernel_size=1, stride=1
        self.conv1 = nn.Conv2d(1, 1, 1, 1)

    def forward(self, x):
        out = self.conv1(x)
        return out
class BackboneModel2(torch.nn.Module):
    """LeNet-like CNN: two conv+BN+ReLU+pool blocks, then two FC layers.

    Expects 1x28x28 inputs so the flattened feature size is 4*4*50.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.bn1 = nn.BatchNorm2d(self.conv1.out_channels)
        self.bn2 = nn.BatchNorm2d(self.conv2.out_channels)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        # conv -> bn -> relu -> 2x2 max-pool, twice
        out = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), 2, 2)
        out = F.max_pool2d(F.relu(self.bn2(self.conv2(out))), 2, 2)
        # flatten everything but the batch dimension
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        return self.fc2(out)
class BigModel(torch.nn.Module):
    """Composite model: chains both backbones and a final 10->2 classifier."""

    def __init__(self):
        super().__init__()
        self.backbone1 = BackboneModel1()
        self.backbone2 = BackboneModel2()
        self.fc3 = nn.Linear(10, 2)

    def forward(self, x):
        out = self.backbone2(self.backbone1(x))
        return self.fc3(out)
class GraphUtilsTestCase(TestCase):
    """Tests for nni._graph_utils graph building utilities."""

    def test_build_module_graph(self):
        """Leaf modules and predecessor/successor lookup on a nested model."""
        big_model = BigModel()
        g = build_module_graph(big_model, torch.randn(2, 1, 28, 28))
        print(g.name_to_node.keys())
        # set literal instead of set([...])
        leaf_modules = {
            'backbone1.conv1', 'backbone2.bn1', 'backbone2.bn2', 'backbone2.conv1',
            'backbone2.conv2', 'backbone2.fc1', 'backbone2.fc2', 'fc3'
        }
        assert set(g.leaf_modules) == leaf_modules
        assert not leaf_modules - set(g.name_to_node.keys())
        assert g.find_successors('backbone2.conv1') == ['backbone2.bn1']
        assert g.find_successors('backbone2.conv2') == ['backbone2.bn2']
        assert g.find_predecessors('backbone2.bn1') == ['backbone2.conv1']
        assert g.find_predecessors('backbone2.bn2') == ['backbone2.conv2']

    def _test_graph(self, model, dummy_input, expected_file):
        """Build a graph proto for `model` and compare it node-by-node
        against the expectation file `expected_file` (protobuf text format)."""
        actual_proto, _ = build_graph(model, dummy_input)
        assert os.path.exists(expected_file), expected_file
        with open(expected_file, "r") as f:
            expected_str = f.read()
        expected_proto = GraphDef()
        text_format.Parse(expected_str, expected_proto)

        # assertEqual, not the deprecated assertEquals alias
        # (removed in Python 3.12)
        self.assertEqual(len(expected_proto.node), len(actual_proto.node))
        for i in range(len(expected_proto.node)):
            expected_node = expected_proto.node[i]
            actual_node = actual_proto.node[i]
            self.assertEqual(expected_node.name, actual_node.name)
            self.assertEqual(expected_node.op, actual_node.op)
            self.assertEqual(expected_node.input, actual_node.input)
            self.assertEqual(expected_node.device, actual_node.device)
            self.assertEqual(
                sorted(expected_node.attr.keys()), sorted(actual_node.attr.keys()))

    # NOTE(review): lexicographic version compare is fragile
    # (e.g. "1.10" < "1.4.0"); kept for consistency with the codebase.
    @unittest.skipIf(torch.__version__ < "1.4.0", "not supported")
    def test_graph_module1(self):
        """Graph of a module wrapping a single Linear layer."""
        dummy_input = (torch.zeros(1, 3),)

        class myLinear(torch.nn.Module):
            def __init__(self):
                super(myLinear, self).__init__()
                self.l = torch.nn.Linear(3, 5)

            def forward(self, x):
                return self.l(x)

        self._test_graph(
            myLinear(), dummy_input,
            os.path.join(os.path.dirname(__file__), "expect", "test_graph_module1.expect")
        )

    @unittest.skipIf(torch.__version__ < "1.4.0", "not supported")
    def test_graph_module2(self):
        """Graph with submodules named like parameters ('weight', 'bias')."""
        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = nn.Linear(5, 3)
                self.bias = nn.Linear(5, 3)
                self.module = nn.Linear(6, 1)

            def forward(self, x):
                tensors = [self.weight(x), self.bias(x)]
                self.module(torch.cat(tensors, dim=1))
                return x

        self._test_graph(
            MyModule(), torch.randn(4, 5),
            os.path.join(os.path.dirname(__file__), "expect", "test_graph_module2.expect")
        )

    @unittest.skipIf(torch.__version__ < "1.4.0", "not supported")
    def test_graph_module3(self):
        """Graph with an nn.ModuleList whose children have numeric names."""
        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.module = nn.ModuleList([
                    nn.Linear(5, 3),
                    nn.Linear(3, 1)
                ])

            def forward(self, x):
                x = self.module[0](x)
                x = self.module[1](x)
                return x

        self._test_graph(
            MyModule(), torch.randn(4, 5),
            os.path.join(os.path.dirname(__file__), "expect", "test_graph_module3.expect")
        )
# Allow running this test module directly: `python test_graph_utils.py`.
if __name__ == '__main__':
    main()
src/sdk/pynni/tests/test_model_speedup.py
0 → 100644
View file @
5c861676
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
os
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
torchvision.models.vgg
import
vgg16
from
torchvision.models.resnet
import
resnet18
from
unittest
import
TestCase
,
main
from
nni.compression.torch
import
L1FilterPruner
from
nni.compression.speedup.torch
import
ModelSpeedup
class BackboneModel1(nn.Module):
    """Minimal backbone: a single 1x1 convolution with one input and one
    output channel, so spatial dimensions are preserved."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 1, 1, 1)

    def forward(self, x):
        out = self.conv1(x)
        return out
class BackboneModel2(torch.nn.Module):
    """LeNet-style CNN for 1x28x28 inputs: two conv/BN/ReLU/max-pool stages
    followed by two fully-connected layers producing 10 logits."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.bn1 = nn.BatchNorm2d(self.conv1.out_channels)
        self.bn2 = nn.BatchNorm2d(self.conv2.out_channels)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        # Stage 1: conv -> BN -> ReLU -> 2x2 max-pool.
        out = F.max_pool2d(F.relu(self.bn1(self.conv1(x))), 2, 2)
        # Stage 2: same pattern.
        out = F.max_pool2d(F.relu(self.bn2(self.conv2(out))), 2, 2)
        # Flatten all but the batch dimension, then classify.
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        return self.fc2(out)
class BigModel(torch.nn.Module):
    """Composite model: both backbones chained, then a small classifier head
    mapping the 10 backbone logits to 2 outputs."""

    def __init__(self):
        super().__init__()
        self.backbone1 = BackboneModel1()
        self.backbone2 = BackboneModel2()
        # Head: Linear -> BatchNorm1d -> in-place ReLU -> Linear.
        self.fc3 = nn.Sequential(
            nn.Linear(10, 10),
            nn.BatchNorm1d(10),
            nn.ReLU(inplace=True),
            nn.Linear(10, 2)
        )

    def forward(self, x):
        return self.fc3(self.backbone2(self.backbone1(x)))
# Pruning ratio applied to every Conv2d layer.
SPARSITY = 0.5

def prune_model_l1(model):
    """Prune `model`'s Conv2d layers with the L1-filter pruner and export the
    pruned weights and masks to files in the current working directory."""
    pruner = L1FilterPruner(model, [{
        'sparsity': SPARSITY,
        'op_types': ['Conv2d']
    }])
    pruner.compress()
    pruner.export_model(model_path='./11_model.pth', mask_path='./l1_mask.pth')
class SpeedupTestCase(TestCase):
    """End-to-end checks that ModelSpeedup physically shrinks the layers that
    were pruned, while leaving the model in its original training mode."""

    def test_speedup_vgg16(self):
        prune_model_l1(vgg16())
        model = vgg16()
        model.train()
        speedup = ModelSpeedup(model, torch.randn(2, 3, 32, 32), './l1_mask.pth')
        speedup.speedup_model()

        reference = vgg16()
        assert model.training
        # Conv out-channels and the first classifier's in-features are halved.
        assert model.features[2].out_channels == int(reference.features[2].out_channels * SPARSITY)
        assert model.classifier[0].in_features == int(reference.classifier[0].in_features * SPARSITY)

    # TODO: re-enable once resnet is supported by speedup:
    #     def test_speedup_resnet(self): model = resnet18() ...

    def test_speedup_bigmodel(self):
        prune_model_l1(BigModel())
        model = BigModel()
        model.train()
        speedup = ModelSpeedup(model, torch.randn(2, 1, 28, 28), './l1_mask.pth')
        speedup.speedup_model()

        reference = BigModel()
        assert model.training
        # conv1's outputs, conv2's inputs/outputs and fc1's inputs all shrink.
        assert model.backbone2.conv1.out_channels == int(reference.backbone2.conv1.out_channels * SPARSITY)
        assert model.backbone2.conv2.in_channels == int(reference.backbone2.conv2.in_channels * SPARSITY)
        assert model.backbone2.conv2.out_channels == int(reference.backbone2.conv2.out_channels * SPARSITY)
        assert model.backbone2.fc1.in_features == int(reference.backbone2.fc1.in_features * SPARSITY)

    def tearDown(self):
        # Remove the artifacts written by prune_model_l1 after every test.
        os.remove('./11_model.pth')
        os.remove('./l1_mask.pth')
# Allow running this test module directly: `python test_model_speedup.py`.
if __name__ == '__main__':
    main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment