Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
9a68cdb2
Unverified
Commit
9a68cdb2
authored
Oct 11, 2021
by
Jiahang Xu
Committed by
GitHub
Oct 11, 2021
Browse files
Fix: refine shape attribute (#4214)
parent
50dc05d7
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
57 additions
and
51 deletions
+57
-51
nni/retiarii/converter/graph_gen.py
nni/retiarii/converter/graph_gen.py
+15
-13
nni/retiarii/converter/utils.py
nni/retiarii/converter/utils.py
+7
-6
nni/retiarii/graph.py
nni/retiarii/graph.py
+5
-3
nni/retiarii/operation.py
nni/retiarii/operation.py
+10
-9
nni/retiarii/operation_def/tf_op_def.py
nni/retiarii/operation_def/tf_op_def.py
+1
-1
test/ut/retiarii/mnist-tensorflow.json
test/ut/retiarii/mnist-tensorflow.json
+9
-9
test/ut/retiarii/test_convert_shape.py
test/ut/retiarii/test_convert_shape.py
+10
-10
No files found.
nni/retiarii/converter/graph_gen.py
View file @
9a68cdb2
...
...
@@ -707,19 +707,21 @@ class GraphConverterWithShape(GraphConverter):
for
ir_node
in
ir_model
.
get_nodes
():
if
ir_node
.
operation
.
parameters
is
None
:
ir_node
.
operation
.
parameters
=
{}
ir_node
.
operation
.
parame
te
r
s
.
setdefault
(
'input_shape'
,
[])
ir_node
.
operation
.
parame
te
r
s
.
setdefault
(
'output_shape'
,
[])
ir_node
.
operation
.
attribu
tes
.
setdefault
(
'input_shape'
,
[])
ir_node
.
operation
.
attribu
tes
.
setdefault
(
'output_shape'
,
[])
def
_trace_module
(
self
,
module
,
module_name
,
ir_model
:
'Model'
,
dummy_input
):
# First, trace the whole graph
tm_graph
=
self
.
_trace
(
module
,
dummy_input
)
for
node
in
tm_graph
.
nodes
():
parameters
=
_extract_info_from_trace_node
(
node
)
shape_parameters
,
parameters
=
_extract_info_from_trace_node
(
node
)
# '__module.convpool/__module.convpool.1/__module.convpool.1.conv'
ir_node
=
match_node
(
ir_model
,
node
,
module_name
)
if
ir_node
is
not
None
:
ir_node
.
operation
.
parameters
.
update
(
parameters
)
ir_node
.
operation
.
attributes
.
update
(
shape_parameters
)
if
parameters
:
ir_node
.
operation
.
parameters
.
update
(
parameters
)
self
.
propagate_shape
(
ir_model
)
...
...
@@ -735,7 +737,7 @@ class GraphConverterWithShape(GraphConverter):
cand_name
=
build_cand_name
(
cand_name
,
submodule
.
label
)
# TODO: Feed the exact input tensor if user provides input,
# in case the path changes according to input data.
lc_inputs
=
[
torch
.
randn
(
shape
)
for
shape
in
lc_node
.
operation
.
parame
te
r
s
[
'input_shape'
]]
lc_inputs
=
[
torch
.
randn
(
shape
)
for
shape
in
lc_node
.
operation
.
attribu
tes
[
'input_shape'
]]
self
.
_trace_module
(
cand
,
cand_name
,
ir_model
,
lc_inputs
)
def
propagate_shape
(
self
,
ir_model
:
'Model'
):
...
...
@@ -753,8 +755,8 @@ class GraphConverterWithShape(GraphConverter):
cand_node
=
ir_model
.
get_node_by_name
(
cand_name
)
if
_without_shape_info
(
cand_node
):
propagate_shape_for_graph
(
ir_model
.
graphs
[
cand_name
])
graph_node
.
operation
.
parame
te
r
s
[
'input_shape'
]
=
cand_node
.
operation
.
parame
te
r
s
[
'input_shape'
]
graph_node
.
operation
.
parame
te
r
s
[
'output_shape'
]
=
cand_node
.
operation
.
parame
te
r
s
[
'output_shape'
]
graph_node
.
operation
.
attribu
tes
[
'input_shape'
]
=
cand_node
.
operation
.
attribu
tes
[
'input_shape'
]
graph_node
.
operation
.
attribu
tes
[
'output_shape'
]
=
cand_node
.
operation
.
attribu
tes
[
'output_shape'
]
else
:
input_shape
=
[[]]
*
len
(
graph
.
input_node
.
operation
.
io_names
or
[])
output_shape
=
[[]]
*
len
(
graph
.
output_node
.
operation
.
io_names
or
[])
...
...
@@ -763,17 +765,17 @@ class GraphConverterWithShape(GraphConverter):
if
_without_shape_info
(
node
):
if
node
.
name
in
ir_model
.
graphs
:
propagate_shape_for_graph
(
ir_model
.
graphs
[
node
.
name
])
if
node
.
operation
.
parame
te
r
s
[
'input_shape'
]:
input_shape
[
edge
.
head_slot
or
0
]
=
node
.
operation
.
parame
te
r
s
[
'input_shape'
][
edge
.
tail_slot
or
0
]
graph_node
.
operation
.
parame
te
r
s
[
'input_shape'
]
=
input_shape
if
node
.
operation
.
attribu
tes
[
'input_shape'
]:
input_shape
[
edge
.
head_slot
or
0
]
=
node
.
operation
.
attribu
tes
[
'input_shape'
][
edge
.
tail_slot
or
0
]
graph_node
.
operation
.
attribu
tes
[
'input_shape'
]
=
input_shape
for
edge
in
graph
.
output_node
.
incoming_edges
:
node
=
edge
.
head
if
_without_shape_info
(
node
):
if
node
.
name
in
ir_model
.
graphs
:
propagate_shape_for_graph
(
ir_model
.
graphs
[
node
.
name
])
if
node
.
operation
.
parame
te
r
s
[
'output_shape'
]:
output_shape
[
edge
.
tail_slot
or
0
]
=
node
.
operation
.
parame
te
r
s
[
'output_shape'
][
edge
.
head_slot
or
0
]
graph_node
.
operation
.
parame
te
r
s
[
'output_shape'
]
=
output_shape
if
node
.
operation
.
attribu
tes
[
'output_shape'
]:
output_shape
[
edge
.
tail_slot
or
0
]
=
node
.
operation
.
attribu
tes
[
'output_shape'
][
edge
.
head_slot
or
0
]
graph_node
.
operation
.
attribu
tes
[
'output_shape'
]
=
output_shape
propagate_shape_for_graph
(
graph_node
.
graph
)
...
...
nni/retiarii/converter/utils.py
View file @
9a68cdb2
...
...
@@ -56,15 +56,16 @@ def _extract_info_from_trace_node(trace_node):
if
shape
:
output_shape
.
append
(
shape
)
parameters
=
{
shape_
parameters
=
{
'input_shape'
:
input_shape
,
'output_shape'
:
output_shape
,
}
if
trace_node
.
kind
()
==
'aten::cat'
:
parameters
[
'dim'
]
=
inputs
[
1
].
toIValue
()
return
parameters
parameters
=
{
'dim'
:
inputs
[
1
].
toIValue
()}
return
shape_parameters
,
parameters
else
:
return
shape_parameters
,
None
def
is_layerchoice_node
(
ir_node
:
Node
):
...
...
@@ -100,7 +101,7 @@ def match_node(ir_model: Model, torch_node, prefix=''):
graph
=
ir_model
.
graphs
.
get
(
full_name
)
if
graph
is
not
None
:
for
node
in
graph
.
get_nodes_by_type
(
torch_node
.
kind
()):
if
not
node
.
operation
.
parame
te
r
s
[
'input_shape'
]:
if
not
node
.
operation
.
attribu
tes
[
'input_shape'
]:
return
node
return
None
else
:
...
...
@@ -108,4 +109,4 @@ def match_node(ir_model: Model, torch_node, prefix=''):
def
_without_shape_info
(
node
:
Node
):
return
not
node
.
operation
.
parame
te
r
s
[
'input_shape'
]
and
not
node
.
operation
.
parame
te
r
s
[
'output_shape'
]
return
not
node
.
operation
.
attribu
tes
[
'input_shape'
]
and
not
node
.
operation
.
attribu
tes
[
'output_shape'
]
nni/retiarii/graph.py
View file @
9a68cdb2
...
...
@@ -603,16 +603,18 @@ class Node:
@
staticmethod
def
_load
(
graph
:
Graph
,
name
:
str
,
ir
:
Any
)
->
'Node'
:
if
ir
[
'operation'
][
'type'
]
==
'_cell'
:
op
=
Cell
(
ir
[
'operation'
][
'cell_name'
],
ir
[
'operation'
].
get
(
'parameters'
,
{}))
op
=
Cell
(
ir
[
'operation'
][
'cell_name'
],
ir
[
'operation'
].
get
(
'parameters'
,
{}),
attributes
=
ir
[
'operation'
].
get
(
'attributes'
,
{}))
else
:
op
=
Operation
.
new
(
ir
[
'operation'
][
'type'
],
ir
[
'operation'
].
get
(
'parameters'
,
{}))
op
=
Operation
.
new
(
ir
[
'operation'
][
'type'
],
ir
[
'operation'
].
get
(
'parameters'
,
{}),
attributes
=
ir
[
'operation'
].
get
(
'attributes'
,
{}))
node
=
Node
(
graph
,
uid
(),
name
,
op
)
if
'label'
in
ir
:
node
.
update_label
(
ir
[
'label'
])
return
node
def
_dump
(
self
)
->
Any
:
ret
=
{
'operation'
:
{
'type'
:
self
.
operation
.
type
,
'parameters'
:
self
.
operation
.
parameters
}}
ret
=
{
'operation'
:
{
'type'
:
self
.
operation
.
type
,
'parameters'
:
self
.
operation
.
parameters
,
'attributes'
:
self
.
operation
.
attributes
}}
if
isinstance
(
self
.
operation
,
Cell
):
ret
[
'operation'
][
'cell_name'
]
=
self
.
operation
.
cell_name
if
self
.
label
is
not
None
:
...
...
nni/retiarii/operation.py
View file @
9a68cdb2
...
...
@@ -34,10 +34,11 @@ class Operation:
Arbitrary key-value parameters (e.g. kernel_size).
"""
def
__init__
(
self
,
type_name
:
str
,
parameters
:
Dict
[
str
,
Any
],
_internal
:
bool
=
False
):
def
__init__
(
self
,
type_name
:
str
,
parameters
:
Dict
[
str
,
Any
]
=
{}
,
_internal
:
bool
=
False
,
attributes
:
Dict
[
str
,
Any
]
=
{}
):
assert
_internal
,
'`Operation()` is private, use `Operation.new()` instead'
self
.
type
:
str
=
type_name
self
.
parameters
:
Dict
[
str
,
Any
]
=
parameters
self
.
attributes
:
Dict
[
str
,
Any
]
=
attributes
def
to_init_code
(
self
,
field
:
str
)
->
str
:
raise
NotImplementedError
()
...
...
@@ -52,9 +53,10 @@ class Operation:
return
True
@
staticmethod
def
new
(
type_name
:
str
,
parameters
:
Dict
[
str
,
Any
]
=
None
,
cell_name
:
str
=
None
)
->
'Operation'
:
if
parameters
is
None
:
parameters
=
{}
def
new
(
type_name
:
str
,
parameters
:
Dict
[
str
,
Any
]
=
None
,
cell_name
:
str
=
None
,
attributes
:
Dict
[
str
,
Any
]
=
None
)
->
'Operation'
:
parameters
=
parameters
or
{}
attributes
=
attributes
or
{}
if
type_name
==
'_cell'
:
# NOTE: cell_name is the same as its Node's name, when the cell is wrapped within the node
return
Cell
(
cell_name
,
parameters
)
...
...
@@ -67,7 +69,7 @@ class Operation:
cls
=
TensorFlowOperation
.
_find_subclass
(
type_name
)
else
:
raise
ValueError
(
f
'Unsupported framework:
{
debug_configs
.
framework
}
'
)
return
cls
(
type_name
,
parameters
,
_internal
=
True
)
return
cls
(
type_name
,
parameters
,
_internal
=
True
,
attributes
=
attributes
)
@
classmethod
def
_find_subclass
(
cls
,
subclass_name
):
...
...
@@ -205,12 +207,11 @@ class Cell(PyTorchOperation):
No real usage. Exists for compatibility with base class.
"""
def
__init__
(
self
,
cell_name
:
str
,
parameters
:
Dict
[
str
,
Any
]
=
None
):
def
__init__
(
self
,
cell_name
:
str
,
parameters
:
Dict
[
str
,
Any
]
=
None
,
attributes
:
Dict
[
str
,
Any
]
=
None
):
self
.
type
=
'_cell'
self
.
cell_name
=
cell_name
if
parameters
is
None
:
parameters
=
{}
self
.
parameters
=
parameters
self
.
parameters
=
parameters
or
{}
self
.
attributes
=
attributes
or
{}
def
_to_class_name
(
self
):
# TODO: ugly, think about how to refactor this part
...
...
nni/retiarii/operation_def/tf_op_def.py
View file @
9a68cdb2
...
...
@@ -5,7 +5,7 @@ from ..operation import TensorFlowOperation
class
Conv2D
(
TensorFlowOperation
):
def
__init__
(
self
,
type_name
,
parameters
,
_internal
):
def
__init__
(
self
,
type_name
,
parameters
,
_internal
,
attributes
=
None
):
if
'padding'
not
in
parameters
:
parameters
[
'padding'
]
=
'same'
super
().
__init__
(
type_name
,
parameters
,
_internal
)
test/ut/retiarii/mnist-tensorflow.json
View file @
9a68cdb2
...
...
@@ -4,11 +4,11 @@
"outputs"
:
[
"metric"
],
"nodes"
:
{
"stem"
:
{
"operation"
:
{
"type"
:
"_cell"
,
"parameters"
:
{},
"cell_name"
:
"stem"
}},
"flatten"
:
{
"operation"
:
{
"type"
:
"Flatten"
,
"parameters"
:
{}}},
"fc1"
:
{
"operation"
:
{
"type"
:
"Dense"
,
"parameters"
:
{
"units"
:
1024
,
"activation"
:
"relu"
}}},
"fc2"
:
{
"operation"
:
{
"type"
:
"Dense"
,
"parameters"
:
{
"units"
:
10
}}},
"softmax"
:
{
"operation"
:
{
"type"
:
"Softmax"
,
"parameters"
:
{}}}
"stem"
:
{
"operation"
:
{
"type"
:
"_cell"
,
"parameters"
:
{},
"attributes"
:
{},
"cell_name"
:
"stem"
}},
"flatten"
:
{
"operation"
:
{
"type"
:
"Flatten"
,
"parameters"
:
{},
"attributes"
:
{}}},
"fc1"
:
{
"operation"
:
{
"type"
:
"Dense"
,
"parameters"
:
{
"units"
:
1024
,
"activation"
:
"relu"
}
,
"attributes"
:
{}
}},
"fc2"
:
{
"operation"
:
{
"type"
:
"Dense"
,
"parameters"
:
{
"units"
:
10
}
,
"attributes"
:
{}
}},
"softmax"
:
{
"operation"
:
{
"type"
:
"Softmax"
,
"parameters"
:
{},
"attributes"
:
{}}}
},
"edges"
:
[
...
...
@@ -23,10 +23,10 @@
"stem"
:
{
"nodes"
:
{
"conv1"
:
{
"operation"
:
{
"type"
:
"Conv2D"
,
"parameters"
:
{
"filters"
:
32
,
"kernel_size"
:
5
,
"activation"
:
"relu"
}}},
"pool1"
:
{
"operation"
:
{
"type"
:
"MaxPool2D"
,
"parameters"
:
{
"pool_size"
:
2
}}},
"conv2"
:
{
"operation"
:
{
"type"
:
"Conv2D"
,
"parameters"
:
{
"filters"
:
64
,
"kernel_size"
:
5
,
"activation"
:
"relu"
}}},
"pool2"
:
{
"operation"
:
{
"type"
:
"MaxPool2D"
,
"parameters"
:
{
"pool_size"
:
2
}}}
"conv1"
:
{
"operation"
:
{
"type"
:
"Conv2D"
,
"parameters"
:
{
"filters"
:
32
,
"kernel_size"
:
5
,
"activation"
:
"relu"
}
,
"attributes"
:
{}
}},
"pool1"
:
{
"operation"
:
{
"type"
:
"MaxPool2D"
,
"parameters"
:
{
"pool_size"
:
2
}
,
"attributes"
:
{}
}},
"conv2"
:
{
"operation"
:
{
"type"
:
"Conv2D"
,
"parameters"
:
{
"filters"
:
64
,
"kernel_size"
:
5
,
"activation"
:
"relu"
}
,
"attributes"
:
{}
}},
"pool2"
:
{
"operation"
:
{
"type"
:
"MaxPool2D"
,
"parameters"
:
{
"pool_size"
:
2
}
,
"attributes"
:
{}
}}
},
"edges"
:
[
...
...
test/ut/retiarii/test_convert_shape.py
View file @
9a68cdb2
...
...
@@ -24,12 +24,12 @@ class TestShape(unittest.TestCase, ConvertWithShapeMixin):
conv_node
=
model_ir
.
get_nodes_by_type
(
'__torch__.torch.nn.modules.conv.Conv2d'
)[
0
]
relu_node
=
model_ir
.
get_nodes_by_type
(
'__torch__.torch.nn.modules.activation.ReLU'
)[
0
]
pool_node
=
model_ir
.
get_nodes_by_type
(
'__torch__.torch.nn.modules.pooling.MaxPool2d'
)[
0
]
self
.
assertEqual
(
conv_node
.
operation
.
parame
te
r
s
.
get
(
'input_shape'
),
[[
1
,
3
,
224
,
224
]])
self
.
assertEqual
(
conv_node
.
operation
.
parame
te
r
s
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
relu_node
.
operation
.
parame
te
r
s
.
get
(
'input_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
relu_node
.
operation
.
parame
te
r
s
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
pool_node
.
operation
.
parame
te
r
s
.
get
(
'input_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
pool_node
.
operation
.
parame
te
r
s
.
get
(
'output_shape'
),
[[
1
,
1
,
111
,
111
]])
self
.
assertEqual
(
conv_node
.
operation
.
attribu
tes
.
get
(
'input_shape'
),
[[
1
,
3
,
224
,
224
]])
self
.
assertEqual
(
conv_node
.
operation
.
attribu
tes
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
relu_node
.
operation
.
attribu
tes
.
get
(
'input_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
relu_node
.
operation
.
attribu
tes
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
pool_node
.
operation
.
attribu
tes
.
get
(
'input_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
pool_node
.
operation
.
attribu
tes
.
get
(
'output_shape'
),
[[
1
,
1
,
111
,
111
]])
def
test_nested_module
(
self
):
class
ConvRelu
(
nn
.
Module
):
...
...
@@ -54,8 +54,8 @@ class TestShape(unittest.TestCase, ConvertWithShapeMixin):
# check if shape propagation works
cell_node
=
model_ir
.
get_nodes_by_type
(
'_cell'
)[
0
]
self
.
assertEqual
(
cell_node
.
operation
.
parame
te
r
s
.
get
(
'input_shape'
),
[[
1
,
3
,
224
,
224
]])
self
.
assertEqual
(
cell_node
.
operation
.
parame
te
r
s
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
cell_node
.
operation
.
attribu
tes
.
get
(
'input_shape'
),
[[
1
,
3
,
224
,
224
]])
self
.
assertEqual
(
cell_node
.
operation
.
attribu
tes
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
def
test_layerchoice
(
self
):
class
ConvNet
(
nn
.
Module
):
...
...
@@ -75,5 +75,5 @@ class TestShape(unittest.TestCase, ConvertWithShapeMixin):
# check shape info of each candidates
conv_nodes
=
model_ir
.
get_nodes_by_type
(
'__torch__.torch.nn.modules.conv.Conv2d'
)
self
.
assertEqual
(
conv_nodes
[
0
].
operation
.
parame
te
r
s
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
conv_nodes
[
1
].
operation
.
parame
te
r
s
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
conv_nodes
[
0
].
operation
.
attribu
tes
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
conv_nodes
[
1
].
operation
.
attribu
tes
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment