Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
9a68cdb2
Unverified
Commit
9a68cdb2
authored
Oct 11, 2021
by
Jiahang Xu
Committed by
GitHub
Oct 11, 2021
Browse files
Fix: refine shape attribute (#4214)
parent
50dc05d7
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
57 additions
and
51 deletions
+57
-51
nni/retiarii/converter/graph_gen.py
nni/retiarii/converter/graph_gen.py
+15
-13
nni/retiarii/converter/utils.py
nni/retiarii/converter/utils.py
+7
-6
nni/retiarii/graph.py
nni/retiarii/graph.py
+5
-3
nni/retiarii/operation.py
nni/retiarii/operation.py
+10
-9
nni/retiarii/operation_def/tf_op_def.py
nni/retiarii/operation_def/tf_op_def.py
+1
-1
test/ut/retiarii/mnist-tensorflow.json
test/ut/retiarii/mnist-tensorflow.json
+9
-9
test/ut/retiarii/test_convert_shape.py
test/ut/retiarii/test_convert_shape.py
+10
-10
No files found.
nni/retiarii/converter/graph_gen.py
View file @
9a68cdb2
...
@@ -707,18 +707,20 @@ class GraphConverterWithShape(GraphConverter):
...
@@ -707,18 +707,20 @@ class GraphConverterWithShape(GraphConverter):
for
ir_node
in
ir_model
.
get_nodes
():
for
ir_node
in
ir_model
.
get_nodes
():
if
ir_node
.
operation
.
parameters
is
None
:
if
ir_node
.
operation
.
parameters
is
None
:
ir_node
.
operation
.
parameters
=
{}
ir_node
.
operation
.
parameters
=
{}
ir_node
.
operation
.
parame
te
r
s
.
setdefault
(
'input_shape'
,
[])
ir_node
.
operation
.
attribu
tes
.
setdefault
(
'input_shape'
,
[])
ir_node
.
operation
.
parame
te
r
s
.
setdefault
(
'output_shape'
,
[])
ir_node
.
operation
.
attribu
tes
.
setdefault
(
'output_shape'
,
[])
def
_trace_module
(
self
,
module
,
module_name
,
ir_model
:
'Model'
,
dummy_input
):
def
_trace_module
(
self
,
module
,
module_name
,
ir_model
:
'Model'
,
dummy_input
):
# First, trace the whole graph
# First, trace the whole graph
tm_graph
=
self
.
_trace
(
module
,
dummy_input
)
tm_graph
=
self
.
_trace
(
module
,
dummy_input
)
for
node
in
tm_graph
.
nodes
():
for
node
in
tm_graph
.
nodes
():
parameters
=
_extract_info_from_trace_node
(
node
)
shape_parameters
,
parameters
=
_extract_info_from_trace_node
(
node
)
# '__module.convpool/__module.convpool.1/__module.convpool.1.conv'
# '__module.convpool/__module.convpool.1/__module.convpool.1.conv'
ir_node
=
match_node
(
ir_model
,
node
,
module_name
)
ir_node
=
match_node
(
ir_model
,
node
,
module_name
)
if
ir_node
is
not
None
:
if
ir_node
is
not
None
:
ir_node
.
operation
.
attributes
.
update
(
shape_parameters
)
if
parameters
:
ir_node
.
operation
.
parameters
.
update
(
parameters
)
ir_node
.
operation
.
parameters
.
update
(
parameters
)
self
.
propagate_shape
(
ir_model
)
self
.
propagate_shape
(
ir_model
)
...
@@ -735,7 +737,7 @@ class GraphConverterWithShape(GraphConverter):
...
@@ -735,7 +737,7 @@ class GraphConverterWithShape(GraphConverter):
cand_name
=
build_cand_name
(
cand_name
,
submodule
.
label
)
cand_name
=
build_cand_name
(
cand_name
,
submodule
.
label
)
# TODO: Feed the exact input tensor if user provides input,
# TODO: Feed the exact input tensor if user provides input,
# in case the path changes according to input data.
# in case the path changes according to input data.
lc_inputs
=
[
torch
.
randn
(
shape
)
for
shape
in
lc_node
.
operation
.
parame
te
r
s
[
'input_shape'
]]
lc_inputs
=
[
torch
.
randn
(
shape
)
for
shape
in
lc_node
.
operation
.
attribu
tes
[
'input_shape'
]]
self
.
_trace_module
(
cand
,
cand_name
,
ir_model
,
lc_inputs
)
self
.
_trace_module
(
cand
,
cand_name
,
ir_model
,
lc_inputs
)
def
propagate_shape
(
self
,
ir_model
:
'Model'
):
def
propagate_shape
(
self
,
ir_model
:
'Model'
):
...
@@ -753,8 +755,8 @@ class GraphConverterWithShape(GraphConverter):
...
@@ -753,8 +755,8 @@ class GraphConverterWithShape(GraphConverter):
cand_node
=
ir_model
.
get_node_by_name
(
cand_name
)
cand_node
=
ir_model
.
get_node_by_name
(
cand_name
)
if
_without_shape_info
(
cand_node
):
if
_without_shape_info
(
cand_node
):
propagate_shape_for_graph
(
ir_model
.
graphs
[
cand_name
])
propagate_shape_for_graph
(
ir_model
.
graphs
[
cand_name
])
graph_node
.
operation
.
parame
te
r
s
[
'input_shape'
]
=
cand_node
.
operation
.
parame
te
r
s
[
'input_shape'
]
graph_node
.
operation
.
attribu
tes
[
'input_shape'
]
=
cand_node
.
operation
.
attribu
tes
[
'input_shape'
]
graph_node
.
operation
.
parame
te
r
s
[
'output_shape'
]
=
cand_node
.
operation
.
parame
te
r
s
[
'output_shape'
]
graph_node
.
operation
.
attribu
tes
[
'output_shape'
]
=
cand_node
.
operation
.
attribu
tes
[
'output_shape'
]
else
:
else
:
input_shape
=
[[]]
*
len
(
graph
.
input_node
.
operation
.
io_names
or
[])
input_shape
=
[[]]
*
len
(
graph
.
input_node
.
operation
.
io_names
or
[])
output_shape
=
[[]]
*
len
(
graph
.
output_node
.
operation
.
io_names
or
[])
output_shape
=
[[]]
*
len
(
graph
.
output_node
.
operation
.
io_names
or
[])
...
@@ -763,17 +765,17 @@ class GraphConverterWithShape(GraphConverter):
...
@@ -763,17 +765,17 @@ class GraphConverterWithShape(GraphConverter):
if
_without_shape_info
(
node
):
if
_without_shape_info
(
node
):
if
node
.
name
in
ir_model
.
graphs
:
if
node
.
name
in
ir_model
.
graphs
:
propagate_shape_for_graph
(
ir_model
.
graphs
[
node
.
name
])
propagate_shape_for_graph
(
ir_model
.
graphs
[
node
.
name
])
if
node
.
operation
.
parame
te
r
s
[
'input_shape'
]:
if
node
.
operation
.
attribu
tes
[
'input_shape'
]:
input_shape
[
edge
.
head_slot
or
0
]
=
node
.
operation
.
parame
te
r
s
[
'input_shape'
][
edge
.
tail_slot
or
0
]
input_shape
[
edge
.
head_slot
or
0
]
=
node
.
operation
.
attribu
tes
[
'input_shape'
][
edge
.
tail_slot
or
0
]
graph_node
.
operation
.
parame
te
r
s
[
'input_shape'
]
=
input_shape
graph_node
.
operation
.
attribu
tes
[
'input_shape'
]
=
input_shape
for
edge
in
graph
.
output_node
.
incoming_edges
:
for
edge
in
graph
.
output_node
.
incoming_edges
:
node
=
edge
.
head
node
=
edge
.
head
if
_without_shape_info
(
node
):
if
_without_shape_info
(
node
):
if
node
.
name
in
ir_model
.
graphs
:
if
node
.
name
in
ir_model
.
graphs
:
propagate_shape_for_graph
(
ir_model
.
graphs
[
node
.
name
])
propagate_shape_for_graph
(
ir_model
.
graphs
[
node
.
name
])
if
node
.
operation
.
parame
te
r
s
[
'output_shape'
]:
if
node
.
operation
.
attribu
tes
[
'output_shape'
]:
output_shape
[
edge
.
tail_slot
or
0
]
=
node
.
operation
.
parame
te
r
s
[
'output_shape'
][
edge
.
head_slot
or
0
]
output_shape
[
edge
.
tail_slot
or
0
]
=
node
.
operation
.
attribu
tes
[
'output_shape'
][
edge
.
head_slot
or
0
]
graph_node
.
operation
.
parame
te
r
s
[
'output_shape'
]
=
output_shape
graph_node
.
operation
.
attribu
tes
[
'output_shape'
]
=
output_shape
propagate_shape_for_graph
(
graph_node
.
graph
)
propagate_shape_for_graph
(
graph_node
.
graph
)
...
...
nni/retiarii/converter/utils.py
View file @
9a68cdb2
...
@@ -56,15 +56,16 @@ def _extract_info_from_trace_node(trace_node):
...
@@ -56,15 +56,16 @@ def _extract_info_from_trace_node(trace_node):
if
shape
:
if
shape
:
output_shape
.
append
(
shape
)
output_shape
.
append
(
shape
)
parameters
=
{
shape_
parameters
=
{
'input_shape'
:
input_shape
,
'input_shape'
:
input_shape
,
'output_shape'
:
output_shape
,
'output_shape'
:
output_shape
,
}
}
if
trace_node
.
kind
()
==
'aten::cat'
:
if
trace_node
.
kind
()
==
'aten::cat'
:
parameters
[
'dim'
]
=
inputs
[
1
].
toIValue
()
parameters
=
{
'dim'
:
inputs
[
1
].
toIValue
()}
return
shape_parameters
,
parameters
return
parameters
else
:
return
shape_parameters
,
None
def
is_layerchoice_node
(
ir_node
:
Node
):
def
is_layerchoice_node
(
ir_node
:
Node
):
...
@@ -100,7 +101,7 @@ def match_node(ir_model: Model, torch_node, prefix=''):
...
@@ -100,7 +101,7 @@ def match_node(ir_model: Model, torch_node, prefix=''):
graph
=
ir_model
.
graphs
.
get
(
full_name
)
graph
=
ir_model
.
graphs
.
get
(
full_name
)
if
graph
is
not
None
:
if
graph
is
not
None
:
for
node
in
graph
.
get_nodes_by_type
(
torch_node
.
kind
()):
for
node
in
graph
.
get_nodes_by_type
(
torch_node
.
kind
()):
if
not
node
.
operation
.
parame
te
r
s
[
'input_shape'
]:
if
not
node
.
operation
.
attribu
tes
[
'input_shape'
]:
return
node
return
node
return
None
return
None
else
:
else
:
...
@@ -108,4 +109,4 @@ def match_node(ir_model: Model, torch_node, prefix=''):
...
@@ -108,4 +109,4 @@ def match_node(ir_model: Model, torch_node, prefix=''):
def
_without_shape_info
(
node
:
Node
):
def
_without_shape_info
(
node
:
Node
):
return
not
node
.
operation
.
parame
te
r
s
[
'input_shape'
]
and
not
node
.
operation
.
parame
te
r
s
[
'output_shape'
]
return
not
node
.
operation
.
attribu
tes
[
'input_shape'
]
and
not
node
.
operation
.
attribu
tes
[
'output_shape'
]
nni/retiarii/graph.py
View file @
9a68cdb2
...
@@ -603,16 +603,18 @@ class Node:
...
@@ -603,16 +603,18 @@ class Node:
@
staticmethod
@
staticmethod
def
_load
(
graph
:
Graph
,
name
:
str
,
ir
:
Any
)
->
'Node'
:
def
_load
(
graph
:
Graph
,
name
:
str
,
ir
:
Any
)
->
'Node'
:
if
ir
[
'operation'
][
'type'
]
==
'_cell'
:
if
ir
[
'operation'
][
'type'
]
==
'_cell'
:
op
=
Cell
(
ir
[
'operation'
][
'cell_name'
],
ir
[
'operation'
].
get
(
'parameters'
,
{}))
op
=
Cell
(
ir
[
'operation'
][
'cell_name'
],
ir
[
'operation'
].
get
(
'parameters'
,
{}),
attributes
=
ir
[
'operation'
].
get
(
'attributes'
,
{}))
else
:
else
:
op
=
Operation
.
new
(
ir
[
'operation'
][
'type'
],
ir
[
'operation'
].
get
(
'parameters'
,
{}))
op
=
Operation
.
new
(
ir
[
'operation'
][
'type'
],
ir
[
'operation'
].
get
(
'parameters'
,
{}),
attributes
=
ir
[
'operation'
].
get
(
'attributes'
,
{}))
node
=
Node
(
graph
,
uid
(),
name
,
op
)
node
=
Node
(
graph
,
uid
(),
name
,
op
)
if
'label'
in
ir
:
if
'label'
in
ir
:
node
.
update_label
(
ir
[
'label'
])
node
.
update_label
(
ir
[
'label'
])
return
node
return
node
def
_dump
(
self
)
->
Any
:
def
_dump
(
self
)
->
Any
:
ret
=
{
'operation'
:
{
'type'
:
self
.
operation
.
type
,
'parameters'
:
self
.
operation
.
parameters
}}
ret
=
{
'operation'
:
{
'type'
:
self
.
operation
.
type
,
'parameters'
:
self
.
operation
.
parameters
,
'attributes'
:
self
.
operation
.
attributes
}}
if
isinstance
(
self
.
operation
,
Cell
):
if
isinstance
(
self
.
operation
,
Cell
):
ret
[
'operation'
][
'cell_name'
]
=
self
.
operation
.
cell_name
ret
[
'operation'
][
'cell_name'
]
=
self
.
operation
.
cell_name
if
self
.
label
is
not
None
:
if
self
.
label
is
not
None
:
...
...
nni/retiarii/operation.py
View file @
9a68cdb2
...
@@ -34,10 +34,11 @@ class Operation:
...
@@ -34,10 +34,11 @@ class Operation:
Arbitrary key-value parameters (e.g. kernel_size).
Arbitrary key-value parameters (e.g. kernel_size).
"""
"""
def
__init__
(
self
,
type_name
:
str
,
parameters
:
Dict
[
str
,
Any
],
_internal
:
bool
=
False
):
def
__init__
(
self
,
type_name
:
str
,
parameters
:
Dict
[
str
,
Any
]
=
{}
,
_internal
:
bool
=
False
,
attributes
:
Dict
[
str
,
Any
]
=
{}
):
assert
_internal
,
'`Operation()` is private, use `Operation.new()` instead'
assert
_internal
,
'`Operation()` is private, use `Operation.new()` instead'
self
.
type
:
str
=
type_name
self
.
type
:
str
=
type_name
self
.
parameters
:
Dict
[
str
,
Any
]
=
parameters
self
.
parameters
:
Dict
[
str
,
Any
]
=
parameters
self
.
attributes
:
Dict
[
str
,
Any
]
=
attributes
def
to_init_code
(
self
,
field
:
str
)
->
str
:
def
to_init_code
(
self
,
field
:
str
)
->
str
:
raise
NotImplementedError
()
raise
NotImplementedError
()
...
@@ -52,9 +53,10 @@ class Operation:
...
@@ -52,9 +53,10 @@ class Operation:
return
True
return
True
@
staticmethod
@
staticmethod
def
new
(
type_name
:
str
,
parameters
:
Dict
[
str
,
Any
]
=
None
,
cell_name
:
str
=
None
)
->
'Operation'
:
def
new
(
type_name
:
str
,
parameters
:
Dict
[
str
,
Any
]
=
None
,
cell_name
:
str
=
None
,
if
parameters
is
None
:
attributes
:
Dict
[
str
,
Any
]
=
None
)
->
'Operation'
:
parameters
=
{}
parameters
=
parameters
or
{}
attributes
=
attributes
or
{}
if
type_name
==
'_cell'
:
if
type_name
==
'_cell'
:
# NOTE: cell_name is the same as its Node's name, when the cell is wrapped within the node
# NOTE: cell_name is the same as its Node's name, when the cell is wrapped within the node
return
Cell
(
cell_name
,
parameters
)
return
Cell
(
cell_name
,
parameters
)
...
@@ -67,7 +69,7 @@ class Operation:
...
@@ -67,7 +69,7 @@ class Operation:
cls
=
TensorFlowOperation
.
_find_subclass
(
type_name
)
cls
=
TensorFlowOperation
.
_find_subclass
(
type_name
)
else
:
else
:
raise
ValueError
(
f
'Unsupported framework:
{
debug_configs
.
framework
}
'
)
raise
ValueError
(
f
'Unsupported framework:
{
debug_configs
.
framework
}
'
)
return
cls
(
type_name
,
parameters
,
_internal
=
True
)
return
cls
(
type_name
,
parameters
,
_internal
=
True
,
attributes
=
attributes
)
@
classmethod
@
classmethod
def
_find_subclass
(
cls
,
subclass_name
):
def
_find_subclass
(
cls
,
subclass_name
):
...
@@ -205,12 +207,11 @@ class Cell(PyTorchOperation):
...
@@ -205,12 +207,11 @@ class Cell(PyTorchOperation):
No real usage. Exists for compatibility with base class.
No real usage. Exists for compatibility with base class.
"""
"""
def
__init__
(
self
,
cell_name
:
str
,
parameters
:
Dict
[
str
,
Any
]
=
None
):
def
__init__
(
self
,
cell_name
:
str
,
parameters
:
Dict
[
str
,
Any
]
=
None
,
attributes
:
Dict
[
str
,
Any
]
=
None
):
self
.
type
=
'_cell'
self
.
type
=
'_cell'
self
.
cell_name
=
cell_name
self
.
cell_name
=
cell_name
if
parameters
is
None
:
self
.
parameters
=
parameters
or
{}
parameters
=
{}
self
.
attributes
=
attributes
or
{}
self
.
parameters
=
parameters
def
_to_class_name
(
self
):
def
_to_class_name
(
self
):
# TODO: ugly, think about how to refactor this part
# TODO: ugly, think about how to refactor this part
...
...
nni/retiarii/operation_def/tf_op_def.py
View file @
9a68cdb2
...
@@ -5,7 +5,7 @@ from ..operation import TensorFlowOperation
...
@@ -5,7 +5,7 @@ from ..operation import TensorFlowOperation
class
Conv2D
(
TensorFlowOperation
):
class
Conv2D
(
TensorFlowOperation
):
def
__init__
(
self
,
type_name
,
parameters
,
_internal
):
def
__init__
(
self
,
type_name
,
parameters
,
_internal
,
attributes
=
None
):
if
'padding'
not
in
parameters
:
if
'padding'
not
in
parameters
:
parameters
[
'padding'
]
=
'same'
parameters
[
'padding'
]
=
'same'
super
().
__init__
(
type_name
,
parameters
,
_internal
)
super
().
__init__
(
type_name
,
parameters
,
_internal
)
test/ut/retiarii/mnist-tensorflow.json
View file @
9a68cdb2
...
@@ -4,11 +4,11 @@
...
@@ -4,11 +4,11 @@
"outputs"
:
[
"metric"
],
"outputs"
:
[
"metric"
],
"nodes"
:
{
"nodes"
:
{
"stem"
:
{
"operation"
:
{
"type"
:
"_cell"
,
"parameters"
:
{},
"cell_name"
:
"stem"
}},
"stem"
:
{
"operation"
:
{
"type"
:
"_cell"
,
"parameters"
:
{},
"attributes"
:
{},
"cell_name"
:
"stem"
}},
"flatten"
:
{
"operation"
:
{
"type"
:
"Flatten"
,
"parameters"
:
{}}},
"flatten"
:
{
"operation"
:
{
"type"
:
"Flatten"
,
"parameters"
:
{},
"attributes"
:
{}}},
"fc1"
:
{
"operation"
:
{
"type"
:
"Dense"
,
"parameters"
:
{
"units"
:
1024
,
"activation"
:
"relu"
}}},
"fc1"
:
{
"operation"
:
{
"type"
:
"Dense"
,
"parameters"
:
{
"units"
:
1024
,
"activation"
:
"relu"
}
,
"attributes"
:
{}
}},
"fc2"
:
{
"operation"
:
{
"type"
:
"Dense"
,
"parameters"
:
{
"units"
:
10
}}},
"fc2"
:
{
"operation"
:
{
"type"
:
"Dense"
,
"parameters"
:
{
"units"
:
10
}
,
"attributes"
:
{}
}},
"softmax"
:
{
"operation"
:
{
"type"
:
"Softmax"
,
"parameters"
:
{}}}
"softmax"
:
{
"operation"
:
{
"type"
:
"Softmax"
,
"parameters"
:
{},
"attributes"
:
{}}}
},
},
"edges"
:
[
"edges"
:
[
...
@@ -23,10 +23,10 @@
...
@@ -23,10 +23,10 @@
"stem"
:
{
"stem"
:
{
"nodes"
:
{
"nodes"
:
{
"conv1"
:
{
"operation"
:
{
"type"
:
"Conv2D"
,
"parameters"
:
{
"filters"
:
32
,
"kernel_size"
:
5
,
"activation"
:
"relu"
}}},
"conv1"
:
{
"operation"
:
{
"type"
:
"Conv2D"
,
"parameters"
:
{
"filters"
:
32
,
"kernel_size"
:
5
,
"activation"
:
"relu"
}
,
"attributes"
:
{}
}},
"pool1"
:
{
"operation"
:
{
"type"
:
"MaxPool2D"
,
"parameters"
:
{
"pool_size"
:
2
}}},
"pool1"
:
{
"operation"
:
{
"type"
:
"MaxPool2D"
,
"parameters"
:
{
"pool_size"
:
2
}
,
"attributes"
:
{}
}},
"conv2"
:
{
"operation"
:
{
"type"
:
"Conv2D"
,
"parameters"
:
{
"filters"
:
64
,
"kernel_size"
:
5
,
"activation"
:
"relu"
}}},
"conv2"
:
{
"operation"
:
{
"type"
:
"Conv2D"
,
"parameters"
:
{
"filters"
:
64
,
"kernel_size"
:
5
,
"activation"
:
"relu"
}
,
"attributes"
:
{}
}},
"pool2"
:
{
"operation"
:
{
"type"
:
"MaxPool2D"
,
"parameters"
:
{
"pool_size"
:
2
}}}
"pool2"
:
{
"operation"
:
{
"type"
:
"MaxPool2D"
,
"parameters"
:
{
"pool_size"
:
2
}
,
"attributes"
:
{}
}}
},
},
"edges"
:
[
"edges"
:
[
...
...
test/ut/retiarii/test_convert_shape.py
View file @
9a68cdb2
...
@@ -24,12 +24,12 @@ class TestShape(unittest.TestCase, ConvertWithShapeMixin):
...
@@ -24,12 +24,12 @@ class TestShape(unittest.TestCase, ConvertWithShapeMixin):
conv_node
=
model_ir
.
get_nodes_by_type
(
'__torch__.torch.nn.modules.conv.Conv2d'
)[
0
]
conv_node
=
model_ir
.
get_nodes_by_type
(
'__torch__.torch.nn.modules.conv.Conv2d'
)[
0
]
relu_node
=
model_ir
.
get_nodes_by_type
(
'__torch__.torch.nn.modules.activation.ReLU'
)[
0
]
relu_node
=
model_ir
.
get_nodes_by_type
(
'__torch__.torch.nn.modules.activation.ReLU'
)[
0
]
pool_node
=
model_ir
.
get_nodes_by_type
(
'__torch__.torch.nn.modules.pooling.MaxPool2d'
)[
0
]
pool_node
=
model_ir
.
get_nodes_by_type
(
'__torch__.torch.nn.modules.pooling.MaxPool2d'
)[
0
]
self
.
assertEqual
(
conv_node
.
operation
.
parame
te
r
s
.
get
(
'input_shape'
),
[[
1
,
3
,
224
,
224
]])
self
.
assertEqual
(
conv_node
.
operation
.
attribu
tes
.
get
(
'input_shape'
),
[[
1
,
3
,
224
,
224
]])
self
.
assertEqual
(
conv_node
.
operation
.
parame
te
r
s
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
conv_node
.
operation
.
attribu
tes
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
relu_node
.
operation
.
parame
te
r
s
.
get
(
'input_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
relu_node
.
operation
.
attribu
tes
.
get
(
'input_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
relu_node
.
operation
.
parame
te
r
s
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
relu_node
.
operation
.
attribu
tes
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
pool_node
.
operation
.
parame
te
r
s
.
get
(
'input_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
pool_node
.
operation
.
attribu
tes
.
get
(
'input_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
pool_node
.
operation
.
parame
te
r
s
.
get
(
'output_shape'
),
[[
1
,
1
,
111
,
111
]])
self
.
assertEqual
(
pool_node
.
operation
.
attribu
tes
.
get
(
'output_shape'
),
[[
1
,
1
,
111
,
111
]])
def
test_nested_module
(
self
):
def
test_nested_module
(
self
):
class
ConvRelu
(
nn
.
Module
):
class
ConvRelu
(
nn
.
Module
):
...
@@ -54,8 +54,8 @@ class TestShape(unittest.TestCase, ConvertWithShapeMixin):
...
@@ -54,8 +54,8 @@ class TestShape(unittest.TestCase, ConvertWithShapeMixin):
# check if shape propagation works
# check if shape propagation works
cell_node
=
model_ir
.
get_nodes_by_type
(
'_cell'
)[
0
]
cell_node
=
model_ir
.
get_nodes_by_type
(
'_cell'
)[
0
]
self
.
assertEqual
(
cell_node
.
operation
.
parame
te
r
s
.
get
(
'input_shape'
),
[[
1
,
3
,
224
,
224
]])
self
.
assertEqual
(
cell_node
.
operation
.
attribu
tes
.
get
(
'input_shape'
),
[[
1
,
3
,
224
,
224
]])
self
.
assertEqual
(
cell_node
.
operation
.
parame
te
r
s
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
cell_node
.
operation
.
attribu
tes
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
def
test_layerchoice
(
self
):
def
test_layerchoice
(
self
):
class
ConvNet
(
nn
.
Module
):
class
ConvNet
(
nn
.
Module
):
...
@@ -75,5 +75,5 @@ class TestShape(unittest.TestCase, ConvertWithShapeMixin):
...
@@ -75,5 +75,5 @@ class TestShape(unittest.TestCase, ConvertWithShapeMixin):
# check shape info of each candidates
# check shape info of each candidates
conv_nodes
=
model_ir
.
get_nodes_by_type
(
'__torch__.torch.nn.modules.conv.Conv2d'
)
conv_nodes
=
model_ir
.
get_nodes_by_type
(
'__torch__.torch.nn.modules.conv.Conv2d'
)
self
.
assertEqual
(
conv_nodes
[
0
].
operation
.
parame
te
r
s
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
conv_nodes
[
0
].
operation
.
attribu
tes
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
conv_nodes
[
1
].
operation
.
parame
te
r
s
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
conv_nodes
[
1
].
operation
.
attribu
tes
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment