Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
f7b7edac
"src/include/blockwise_tensor_op.cuh" did not exist on "5d2cafcb24097f86d33f7c5243a3c0f3800854ec"
Unverified
Commit
f7b7edac
authored
Nov 23, 2020
by
chicm-ms
Committed by
GitHub
Nov 23, 2020
Browse files
graphutils supports torch17 (#3076)
parent
b6233e52
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
79 additions
and
48 deletions
+79
-48
nni/common/graph_utils.py
nni/common/graph_utils.py
+72
-39
nni/compression/pytorch/utils/mask_conflict.py
nni/compression/pytorch/utils/mask_conflict.py
+5
-3
pipelines/fast-test.yml
pipelines/fast-test.yml
+2
-2
test/ut/sdk/test_compression_utils.py
test/ut/sdk/test_compression_utils.py
+0
-1
test/ut/sdk/test_dependecy_aware.py
test/ut/sdk/test_dependecy_aware.py
+0
-1
test/ut/sdk/test_model_speedup.py
test/ut/sdk/test_model_speedup.py
+0
-1
test/ut/sdk/test_pruners.py
test/ut/sdk/test_pruners.py
+0
-1
No files found.
nni/common/graph_utils.py
View file @
f7b7edac
...
@@ -15,6 +15,7 @@ LIST_CONSTRUCT_KIND = 'prim::ListConstruct'
...
@@ -15,6 +15,7 @@ LIST_CONSTRUCT_KIND = 'prim::ListConstruct'
LIST_UNPACK_KIND
=
'prim::ListUnpack'
LIST_UNPACK_KIND
=
'prim::ListUnpack'
TUPLE_CONSTRUCT_KIND
=
'prim::TupleConstruct'
TUPLE_CONSTRUCT_KIND
=
'prim::TupleConstruct'
TUPLE_UNPACK_KIND
=
'prim::TupleUnpack'
TUPLE_UNPACK_KIND
=
'prim::TupleUnpack'
CONSTANT_KIND
=
'prim::Constant'
_logger
=
logging
.
getLogger
(
__name__
)
_logger
=
logging
.
getLogger
(
__name__
)
...
@@ -68,9 +69,11 @@ class TorchGraph:
...
@@ -68,9 +69,11 @@ class TorchGraph:
'Please provide model & dummy_input or the traced_model as inputs'
)
'Please provide model & dummy_input or the traced_model as inputs'
)
def
_trace
(
self
,
model
,
dummy_input
):
def
_trace
(
self
,
model
,
dummy_input
):
with
torch
.
onnx
.
set_training
(
model
,
False
):
training
=
model
.
training
model
.
eval
()
self
.
trace
=
torch
.
jit
.
trace
(
model
,
dummy_input
)
self
.
trace
=
torch
.
jit
.
trace
(
model
,
dummy_input
)
torch
.
_C
.
_jit_pass_inline
(
self
.
trace
.
graph
)
torch
.
_C
.
_jit_pass_inline
(
self
.
trace
.
graph
)
model
.
train
(
training
)
class
TorchProtoGraph
(
TorchGraph
):
class
TorchProtoGraph
(
TorchGraph
):
...
@@ -282,27 +285,35 @@ class TorchModuleGraph(TorchGraph):
...
@@ -282,27 +285,35 @@ class TorchModuleGraph(TorchGraph):
self
.
global_count
+=
1
self
.
global_count
+=
1
op_type
=
node
.
kind
()
op_type
=
node
.
kind
()
node_group
=
[
node
]
node_group
=
[
node
]
inputs
=
li
st
()
inputs
=
s
e
t
()
outputs
=
li
st
()
outputs
=
s
e
t
()
node_queue
=
queue
.
Queue
()
node_queue
=
queue
.
Queue
()
node_queue
.
put
(
node
)
node_queue
.
put
(
node
)
while
not
node_queue
.
empty
():
while
not
node_queue
.
empty
():
curr_node
=
node_queue
.
get
()
curr_node
=
node_queue
.
get
()
for
_input
in
curr_node
.
inputs
():
for
_input
in
curr_node
.
inputs
():
if
_input
.
node
().
kind
()
==
CONSTANT_KIND
:
continue
input_name
=
_input
.
debugName
()
input_name
=
_input
.
debugName
()
if
input_name
in
output_to_node
and
output_to_node
[
input_name
]
in
nodes
:
if
input_name
in
output_to_node
:
predecessor_node
=
output_to_node
[
input_name
]
for
predecessor_node
in
output_to_node
[
input_name
]:
if
predecessor_node
in
nodes
:
if
not
self
.
_is_key_func
(
predecessor_node
):
if
not
self
.
_is_key_func
(
predecessor_node
):
if
predecessor_node
not
in
node_group
:
node_group
.
append
(
predecessor_node
)
node_group
.
append
(
predecessor_node
)
node_queue
.
put
(
predecessor_node
)
node_queue
.
put
(
predecessor_node
)
else
:
else
:
inputs
.
append
(
input_name
)
inputs
.
add
(
input_name
)
else
:
inputs
.
add
(
input_name
)
else
:
else
:
inputs
.
a
ppen
d
(
input_name
)
inputs
.
a
d
d
(
input_name
)
for
output
in
node
.
outputs
():
for
output
in
node
.
outputs
():
outputs
.
append
(
output
.
debugName
())
if
output
.
node
().
kind
()
==
CONSTANT_KIND
:
continue
outputs
.
add
(
output
.
debugName
())
nodepy
=
NodePyGroup
(
node_name
,
unique_name
,
module_type
,
op_type
,
nodepy
=
NodePyGroup
(
node_name
,
unique_name
,
module_type
,
op_type
,
node_group
,
inputs
=
inputs
,
outputs
=
outputs
,
key_node
=
node
)
node_group
,
inputs
=
list
(
inputs
)
,
outputs
=
list
(
outputs
)
,
key_node
=
node
)
return
nodepy
return
nodepy
def
_expand_module_node
(
self
,
node
,
node_name
,
unique_name
,
op_type
,
nodes
,
def
_expand_module_node
(
self
,
node
,
node_name
,
unique_name
,
op_type
,
nodes
,
...
@@ -342,36 +353,46 @@ class TorchModuleGraph(TorchGraph):
...
@@ -342,36 +353,46 @@ class TorchModuleGraph(TorchGraph):
if
not
op_type
:
if
not
op_type
:
op_type
=
node
.
kind
()
op_type
=
node
.
kind
()
node_group
=
[
node
]
node_group
=
[
node
]
inputs
=
li
st
()
inputs
=
s
e
t
()
outputs
=
li
st
()
outputs
=
s
e
t
()
node_queue
=
queue
.
Queue
()
node_queue
=
queue
.
Queue
()
node_queue
.
put
(
node
)
node_queue
.
put
(
node
)
visited
=
{
node
}
visited
=
{
node
}
while
not
node_queue
.
empty
():
while
not
node_queue
.
empty
():
curr_node
=
node_queue
.
get
()
curr_node
=
node_queue
.
get
()
for
_input
in
curr_node
.
inputs
():
for
_input
in
curr_node
.
inputs
():
if
_input
.
node
().
kind
()
==
CONSTANT_KIND
:
continue
input_name
=
_input
.
debugName
()
input_name
=
_input
.
debugName
()
if
input_name
in
output_to_node
and
output_to_node
[
input_name
]
in
nodes
:
if
input_name
in
output_to_node
:
predecessor_node
=
output_to_node
[
input_name
]
for
predecessor_node
in
output_to_node
[
input_name
]:
if
predecessor_node
in
nodes
:
if
predecessor_node
not
in
visited
:
if
predecessor_node
not
in
visited
:
node_group
.
append
(
predecessor_node
)
node_group
.
append
(
predecessor_node
)
node_queue
.
put
(
predecessor_node
)
node_queue
.
put
(
predecessor_node
)
visited
.
add
(
predecessor_node
)
visited
.
add
(
predecessor_node
)
else
:
else
:
inputs
.
append
(
input_name
)
inputs
.
add
(
input_name
)
else
:
inputs
.
add
(
input_name
)
for
_output
in
curr_node
.
outputs
():
for
_output
in
curr_node
.
outputs
():
if
_output
.
node
().
kind
()
==
CONSTANT_KIND
:
continue
output_name
=
_output
.
debugName
()
output_name
=
_output
.
debugName
()
if
output_name
in
input_to_node
and
input_to_node
[
output_name
]
in
nodes
:
if
output_name
in
input_to_node
:
successor_node
=
input_to_node
[
output_name
]
for
successor_node
in
input_to_node
[
output_name
]:
if
successor_node
in
nodes
:
if
successor_node
not
in
visited
:
if
successor_node
not
in
visited
:
node_group
.
append
(
successor_node
)
node_group
.
append
(
successor_node
)
node_queue
.
put
(
successor_node
)
node_queue
.
put
(
successor_node
)
visited
.
add
(
successor_node
)
visited
.
add
(
successor_node
)
else
:
else
:
outputs
.
append
(
output_name
)
outputs
.
add
(
output_name
)
else
:
outputs
.
add
(
output_name
)
nodepy
=
NodePyGroup
(
node_name
,
unique_name
,
module_type
,
op_type
,
nodepy
=
NodePyGroup
(
node_name
,
unique_name
,
module_type
,
op_type
,
node_group
,
inputs
=
inputs
,
outputs
=
outputs
)
node_group
,
inputs
=
list
(
inputs
)
,
outputs
=
list
(
outputs
)
)
return
nodepy
return
nodepy
def
_extract_cat_info
(
self
,
node_group
,
cpp_node
):
def
_extract_cat_info
(
self
,
node_group
,
cpp_node
):
...
@@ -544,7 +565,7 @@ class TorchModuleGraph(TorchGraph):
...
@@ -544,7 +565,7 @@ class TorchModuleGraph(TorchGraph):
input_to_node
[
_input
].
append
(
node
)
input_to_node
[
_input
].
append
(
node
)
for
output
in
node
.
outputs
:
for
output
in
node
.
outputs
:
assert
not
output
in
output_to_node
,
\
assert
not
output
in
output_to_node
,
\
"One output cannot be generated by multiple nodes
"
"One output cannot be generated by multiple nodes
%s"
%
output
output_to_node
[
output
]
=
node
output_to_node
[
output
]
=
node
return
name_to_node
,
input_to_node
,
output_to_node
return
name_to_node
,
input_to_node
,
output_to_node
...
@@ -642,12 +663,22 @@ class TorchModuleGraph(TorchGraph):
...
@@ -642,12 +663,22 @@ class TorchModuleGraph(TorchGraph):
omit_useless_nodes
=
True
omit_useless_nodes
=
True
graph
=
self
.
trace
.
graph
graph
=
self
.
trace
.
graph
_logger
.
debug
(
graph
)
_logger
.
debug
(
graph
)
# build output mapping, from output debugName to its node
# build input/output mapping, from input/output debugName to its node
output_to_node
=
{
x
.
debugName
():
n
for
n
in
graph
.
nodes
()
input_to_node
=
defaultdict
(
list
)
for
x
in
n
.
outputs
()}
output_to_node
=
defaultdict
(
list
)
# build input mapping, from input debugName to its node
for
node
in
graph
.
nodes
():
input_to_node
=
{
x
.
debugName
():
n
for
n
in
graph
.
nodes
()
if
node
.
kind
()
==
CONSTANT_KIND
:
for
x
in
n
.
inputs
()}
continue
for
x
in
node
.
outputs
():
if
x
.
node
().
kind
()
==
CONSTANT_KIND
:
continue
output_to_node
[
x
.
debugName
()].
append
(
node
)
assert
len
(
output_to_node
[
x
.
debugName
()])
<=
1
,
"One output cannot be generated by multiple nodes %s"
%
x
.
debugName
()
for
x
in
node
.
inputs
():
if
x
.
node
().
kind
()
==
CONSTANT_KIND
:
continue
input_to_node
[
x
.
debugName
()].
append
(
node
)
# build module mapping, from module name to all nodes (as list) under this module scope
# build module mapping, from module name to all nodes (as list) under this module scope
module_to_nodes
=
defaultdict
(
list
)
module_to_nodes
=
defaultdict
(
list
)
# the mapping of function (non-module in forward) to nodes, key is scope name
# the mapping of function (non-module in forward) to nodes, key is scope name
...
@@ -668,6 +699,8 @@ class TorchModuleGraph(TorchGraph):
...
@@ -668,6 +699,8 @@ class TorchModuleGraph(TorchGraph):
# associate module name with their trace graph nodes
# associate module name with their trace graph nodes
for
node
in
graph
.
nodes
():
for
node
in
graph
.
nodes
():
if
node
.
kind
()
==
CONSTANT_KIND
:
continue
module_name
=
self
.
_get_module_name
(
node
.
scopeName
())
module_name
=
self
.
_get_module_name
(
node
.
scopeName
())
if
module_name
in
self
.
leaf_modules
:
if
module_name
in
self
.
leaf_modules
:
module_to_nodes
[
module_name
].
append
(
node
)
module_to_nodes
[
module_name
].
append
(
node
)
...
...
nni/compression/pytorch/utils/mask_conflict.py
View file @
f7b7edac
...
@@ -36,9 +36,11 @@ def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
...
@@ -36,9 +36,11 @@ def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
# this traced model.
# this traced model.
if
traced
is
None
:
if
traced
is
None
:
assert
model
is
not
None
and
dummy_input
is
not
None
assert
model
is
not
None
and
dummy_input
is
not
None
with
torch
.
onnx
.
set_training
(
model
,
False
):
training
=
model
.
training
# We need to trace the model in this way, else it will have problems
model
.
eval
()
# We need to trace the model in eval mode
traced
=
torch
.
jit
.
trace
(
model
,
dummy_input
)
traced
=
torch
.
jit
.
trace
(
model
,
dummy_input
)
model
.
train
(
training
)
fix_group_mask
=
GroupMaskConflict
(
masks
,
model
,
dummy_input
,
traced
)
fix_group_mask
=
GroupMaskConflict
(
masks
,
model
,
dummy_input
,
traced
)
masks
=
fix_group_mask
.
fix_mask
()
masks
=
fix_group_mask
.
fix_mask
()
...
...
pipelines/fast-test.yml
View file @
f7b7edac
...
@@ -34,7 +34,7 @@ jobs:
...
@@ -34,7 +34,7 @@ jobs:
set -e
set -e
sudo apt-get install -y pandoc
sudo apt-get install -y pandoc
python3 -m pip install -U --upgrade pygments
python3 -m pip install -U --upgrade pygments
python3 -m pip install -U torch==1.
5
.0+cpu torchvision==0.
6.0
+cpu -f https://download.pytorch.org/whl/torch_stable.html
python3 -m pip install -U torch==1.
7
.0+cpu torchvision==0.
8.1
+cpu -f https://download.pytorch.org/whl/torch_stable.html
python3 -m pip install -U tensorflow==2.3.1
python3 -m pip install -U tensorflow==2.3.1
python3 -m pip install -U keras==2.4.2
python3 -m pip install -U keras==2.4.2
python3 -m pip install -U gym onnx peewee thop
python3 -m pip install -U gym onnx peewee thop
...
@@ -96,7 +96,7 @@ jobs:
...
@@ -96,7 +96,7 @@ jobs:
-
script
:
|
-
script
:
|
set -e
set -e
python3 -m pip install -U torch==1.
3.1
+cpu torchvision==0.
4.2
+cpu -f https://download.pytorch.org/whl/torch_stable.html
python3 -m pip install -U torch==1.
5.0
+cpu torchvision==0.
6.0
+cpu -f https://download.pytorch.org/whl/torch_stable.html
python3 -m pip install -U tensorflow==1.15.2
python3 -m pip install -U tensorflow==1.15.2
python3 -m pip install -U keras==2.1.6
python3 -m pip install -U keras==2.1.6
python3 -m pip install -U gym onnx peewee
python3 -m pip install -U gym onnx peewee
...
...
test/ut/sdk/test_compression_utils.py
View file @
f7b7edac
...
@@ -61,7 +61,6 @@ channel_dependency_ground_truth = {
...
@@ -61,7 +61,6 @@ channel_dependency_ground_truth = {
unittest
.
TestLoader
.
sortTestMethodsUsing
=
None
unittest
.
TestLoader
.
sortTestMethodsUsing
=
None
@
unittest
.
skipIf
(
torch
.
__version__
>=
'1.6.0'
,
'not supported'
)
class
AnalysisUtilsTest
(
TestCase
):
class
AnalysisUtilsTest
(
TestCase
):
@
unittest
.
skipIf
(
torch
.
__version__
<
"1.3.0"
,
"not supported"
)
@
unittest
.
skipIf
(
torch
.
__version__
<
"1.3.0"
,
"not supported"
)
def
test_channel_dependency
(
self
):
def
test_channel_dependency
(
self
):
...
...
test/ut/sdk/test_dependecy_aware.py
View file @
f7b7edac
...
@@ -47,7 +47,6 @@ def generate_random_sparsity_v2(model):
...
@@ -47,7 +47,6 @@ def generate_random_sparsity_v2(model):
return
cfg_list
return
cfg_list
@
unittest
.
skipIf
(
torch
.
__version__
>=
'1.6.0'
,
'not supported'
)
class
DependencyawareTest
(
TestCase
):
class
DependencyawareTest
(
TestCase
):
@
unittest
.
skipIf
(
torch
.
__version__
<
"1.3.0"
,
"not supported"
)
@
unittest
.
skipIf
(
torch
.
__version__
<
"1.3.0"
,
"not supported"
)
def
test_dependency_aware_pruning
(
self
):
def
test_dependency_aware_pruning
(
self
):
...
...
test/ut/sdk/test_model_speedup.py
View file @
f7b7edac
...
@@ -177,7 +177,6 @@ def channel_prune(model):
...
@@ -177,7 +177,6 @@ def channel_prune(model):
pruner
.
compress
()
pruner
.
compress
()
pruner
.
export_model
(
model_path
=
MODEL_FILE
,
mask_path
=
MASK_FILE
)
pruner
.
export_model
(
model_path
=
MODEL_FILE
,
mask_path
=
MASK_FILE
)
@
unittest
.
skipIf
(
torch
.
__version__
>=
'1.6.0'
,
'not supported'
)
class
SpeedupTestCase
(
TestCase
):
class
SpeedupTestCase
(
TestCase
):
def
test_speedup_vgg16
(
self
):
def
test_speedup_vgg16
(
self
):
prune_model_l1
(
vgg16
())
prune_model_l1
(
vgg16
())
...
...
test/ut/sdk/test_pruners.py
View file @
f7b7edac
...
@@ -264,7 +264,6 @@ class SimpleDataset:
...
@@ -264,7 +264,6 @@ class SimpleDataset:
def
__len__
(
self
):
def
__len__
(
self
):
return
1000
return
1000
@
unittest
.
skipIf
(
torch
.
__version__
>=
'1.6.0'
,
'not supported'
)
class
PrunerTestCase
(
TestCase
):
class
PrunerTestCase
(
TestCase
):
def
test_pruners
(
self
):
def
test_pruners
(
self
):
pruners_test
(
bias
=
True
)
pruners_test
(
bias
=
True
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment