Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
9a68cdb2
"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "37e194f48a56723b4bc8d9e9674236cc7f90db3c"
Unverified
Commit
9a68cdb2
authored
Oct 11, 2021
by
Jiahang Xu
Committed by
GitHub
Oct 11, 2021
Browse files
Fix: refine shape attribute (#4214)
parent
50dc05d7
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
57 additions
and
51 deletions
+57
-51
nni/retiarii/converter/graph_gen.py
nni/retiarii/converter/graph_gen.py
+15
-13
nni/retiarii/converter/utils.py
nni/retiarii/converter/utils.py
+7
-6
nni/retiarii/graph.py
nni/retiarii/graph.py
+5
-3
nni/retiarii/operation.py
nni/retiarii/operation.py
+10
-9
nni/retiarii/operation_def/tf_op_def.py
nni/retiarii/operation_def/tf_op_def.py
+1
-1
test/ut/retiarii/mnist-tensorflow.json
test/ut/retiarii/mnist-tensorflow.json
+9
-9
test/ut/retiarii/test_convert_shape.py
test/ut/retiarii/test_convert_shape.py
+10
-10
No files found.
nni/retiarii/converter/graph_gen.py
View file @
9a68cdb2
...
@@ -707,19 +707,21 @@ class GraphConverterWithShape(GraphConverter):
...
@@ -707,19 +707,21 @@ class GraphConverterWithShape(GraphConverter):
for
ir_node
in
ir_model
.
get_nodes
():
for
ir_node
in
ir_model
.
get_nodes
():
if
ir_node
.
operation
.
parameters
is
None
:
if
ir_node
.
operation
.
parameters
is
None
:
ir_node
.
operation
.
parameters
=
{}
ir_node
.
operation
.
parameters
=
{}
ir_node
.
operation
.
parame
te
r
s
.
setdefault
(
'input_shape'
,
[])
ir_node
.
operation
.
attribu
tes
.
setdefault
(
'input_shape'
,
[])
ir_node
.
operation
.
parame
te
r
s
.
setdefault
(
'output_shape'
,
[])
ir_node
.
operation
.
attribu
tes
.
setdefault
(
'output_shape'
,
[])
def
_trace_module
(
self
,
module
,
module_name
,
ir_model
:
'Model'
,
dummy_input
):
def
_trace_module
(
self
,
module
,
module_name
,
ir_model
:
'Model'
,
dummy_input
):
# First, trace the whole graph
# First, trace the whole graph
tm_graph
=
self
.
_trace
(
module
,
dummy_input
)
tm_graph
=
self
.
_trace
(
module
,
dummy_input
)
for
node
in
tm_graph
.
nodes
():
for
node
in
tm_graph
.
nodes
():
parameters
=
_extract_info_from_trace_node
(
node
)
shape_parameters
,
parameters
=
_extract_info_from_trace_node
(
node
)
# '__module.convpool/__module.convpool.1/__module.convpool.1.conv'
# '__module.convpool/__module.convpool.1/__module.convpool.1.conv'
ir_node
=
match_node
(
ir_model
,
node
,
module_name
)
ir_node
=
match_node
(
ir_model
,
node
,
module_name
)
if
ir_node
is
not
None
:
if
ir_node
is
not
None
:
ir_node
.
operation
.
parameters
.
update
(
parameters
)
ir_node
.
operation
.
attributes
.
update
(
shape_parameters
)
if
parameters
:
ir_node
.
operation
.
parameters
.
update
(
parameters
)
self
.
propagate_shape
(
ir_model
)
self
.
propagate_shape
(
ir_model
)
...
@@ -735,7 +737,7 @@ class GraphConverterWithShape(GraphConverter):
...
@@ -735,7 +737,7 @@ class GraphConverterWithShape(GraphConverter):
cand_name
=
build_cand_name
(
cand_name
,
submodule
.
label
)
cand_name
=
build_cand_name
(
cand_name
,
submodule
.
label
)
# TODO: Feed the exact input tensor if user provides input,
# TODO: Feed the exact input tensor if user provides input,
# in case the path changes according to input data.
# in case the path changes according to input data.
lc_inputs
=
[
torch
.
randn
(
shape
)
for
shape
in
lc_node
.
operation
.
parame
te
r
s
[
'input_shape'
]]
lc_inputs
=
[
torch
.
randn
(
shape
)
for
shape
in
lc_node
.
operation
.
attribu
tes
[
'input_shape'
]]
self
.
_trace_module
(
cand
,
cand_name
,
ir_model
,
lc_inputs
)
self
.
_trace_module
(
cand
,
cand_name
,
ir_model
,
lc_inputs
)
def
propagate_shape
(
self
,
ir_model
:
'Model'
):
def
propagate_shape
(
self
,
ir_model
:
'Model'
):
...
@@ -753,8 +755,8 @@ class GraphConverterWithShape(GraphConverter):
...
@@ -753,8 +755,8 @@ class GraphConverterWithShape(GraphConverter):
cand_node
=
ir_model
.
get_node_by_name
(
cand_name
)
cand_node
=
ir_model
.
get_node_by_name
(
cand_name
)
if
_without_shape_info
(
cand_node
):
if
_without_shape_info
(
cand_node
):
propagate_shape_for_graph
(
ir_model
.
graphs
[
cand_name
])
propagate_shape_for_graph
(
ir_model
.
graphs
[
cand_name
])
graph_node
.
operation
.
parame
te
r
s
[
'input_shape'
]
=
cand_node
.
operation
.
parame
te
r
s
[
'input_shape'
]
graph_node
.
operation
.
attribu
tes
[
'input_shape'
]
=
cand_node
.
operation
.
attribu
tes
[
'input_shape'
]
graph_node
.
operation
.
parame
te
r
s
[
'output_shape'
]
=
cand_node
.
operation
.
parame
te
r
s
[
'output_shape'
]
graph_node
.
operation
.
attribu
tes
[
'output_shape'
]
=
cand_node
.
operation
.
attribu
tes
[
'output_shape'
]
else
:
else
:
input_shape
=
[[]]
*
len
(
graph
.
input_node
.
operation
.
io_names
or
[])
input_shape
=
[[]]
*
len
(
graph
.
input_node
.
operation
.
io_names
or
[])
output_shape
=
[[]]
*
len
(
graph
.
output_node
.
operation
.
io_names
or
[])
output_shape
=
[[]]
*
len
(
graph
.
output_node
.
operation
.
io_names
or
[])
...
@@ -763,17 +765,17 @@ class GraphConverterWithShape(GraphConverter):
...
@@ -763,17 +765,17 @@ class GraphConverterWithShape(GraphConverter):
if
_without_shape_info
(
node
):
if
_without_shape_info
(
node
):
if
node
.
name
in
ir_model
.
graphs
:
if
node
.
name
in
ir_model
.
graphs
:
propagate_shape_for_graph
(
ir_model
.
graphs
[
node
.
name
])
propagate_shape_for_graph
(
ir_model
.
graphs
[
node
.
name
])
if
node
.
operation
.
parame
te
r
s
[
'input_shape'
]:
if
node
.
operation
.
attribu
tes
[
'input_shape'
]:
input_shape
[
edge
.
head_slot
or
0
]
=
node
.
operation
.
parame
te
r
s
[
'input_shape'
][
edge
.
tail_slot
or
0
]
input_shape
[
edge
.
head_slot
or
0
]
=
node
.
operation
.
attribu
tes
[
'input_shape'
][
edge
.
tail_slot
or
0
]
graph_node
.
operation
.
parame
te
r
s
[
'input_shape'
]
=
input_shape
graph_node
.
operation
.
attribu
tes
[
'input_shape'
]
=
input_shape
for
edge
in
graph
.
output_node
.
incoming_edges
:
for
edge
in
graph
.
output_node
.
incoming_edges
:
node
=
edge
.
head
node
=
edge
.
head
if
_without_shape_info
(
node
):
if
_without_shape_info
(
node
):
if
node
.
name
in
ir_model
.
graphs
:
if
node
.
name
in
ir_model
.
graphs
:
propagate_shape_for_graph
(
ir_model
.
graphs
[
node
.
name
])
propagate_shape_for_graph
(
ir_model
.
graphs
[
node
.
name
])
if
node
.
operation
.
parame
te
r
s
[
'output_shape'
]:
if
node
.
operation
.
attribu
tes
[
'output_shape'
]:
output_shape
[
edge
.
tail_slot
or
0
]
=
node
.
operation
.
parame
te
r
s
[
'output_shape'
][
edge
.
head_slot
or
0
]
output_shape
[
edge
.
tail_slot
or
0
]
=
node
.
operation
.
attribu
tes
[
'output_shape'
][
edge
.
head_slot
or
0
]
graph_node
.
operation
.
parame
te
r
s
[
'output_shape'
]
=
output_shape
graph_node
.
operation
.
attribu
tes
[
'output_shape'
]
=
output_shape
propagate_shape_for_graph
(
graph_node
.
graph
)
propagate_shape_for_graph
(
graph_node
.
graph
)
...
...
nni/retiarii/converter/utils.py
View file @
9a68cdb2
...
@@ -56,15 +56,16 @@ def _extract_info_from_trace_node(trace_node):
...
@@ -56,15 +56,16 @@ def _extract_info_from_trace_node(trace_node):
if
shape
:
if
shape
:
output_shape
.
append
(
shape
)
output_shape
.
append
(
shape
)
parameters
=
{
shape_
parameters
=
{
'input_shape'
:
input_shape
,
'input_shape'
:
input_shape
,
'output_shape'
:
output_shape
,
'output_shape'
:
output_shape
,
}
}
if
trace_node
.
kind
()
==
'aten::cat'
:
if
trace_node
.
kind
()
==
'aten::cat'
:
parameters
[
'dim'
]
=
inputs
[
1
].
toIValue
()
parameters
=
{
'dim'
:
inputs
[
1
].
toIValue
()}
return
shape_parameters
,
parameters
return
parameters
else
:
return
shape_parameters
,
None
def
is_layerchoice_node
(
ir_node
:
Node
):
def
is_layerchoice_node
(
ir_node
:
Node
):
...
@@ -100,7 +101,7 @@ def match_node(ir_model: Model, torch_node, prefix=''):
...
@@ -100,7 +101,7 @@ def match_node(ir_model: Model, torch_node, prefix=''):
graph
=
ir_model
.
graphs
.
get
(
full_name
)
graph
=
ir_model
.
graphs
.
get
(
full_name
)
if
graph
is
not
None
:
if
graph
is
not
None
:
for
node
in
graph
.
get_nodes_by_type
(
torch_node
.
kind
()):
for
node
in
graph
.
get_nodes_by_type
(
torch_node
.
kind
()):
if
not
node
.
operation
.
parame
te
r
s
[
'input_shape'
]:
if
not
node
.
operation
.
attribu
tes
[
'input_shape'
]:
return
node
return
node
return
None
return
None
else
:
else
:
...
@@ -108,4 +109,4 @@ def match_node(ir_model: Model, torch_node, prefix=''):
...
@@ -108,4 +109,4 @@ def match_node(ir_model: Model, torch_node, prefix=''):
def
_without_shape_info
(
node
:
Node
):
def
_without_shape_info
(
node
:
Node
):
return
not
node
.
operation
.
parame
te
r
s
[
'input_shape'
]
and
not
node
.
operation
.
parame
te
r
s
[
'output_shape'
]
return
not
node
.
operation
.
attribu
tes
[
'input_shape'
]
and
not
node
.
operation
.
attribu
tes
[
'output_shape'
]
nni/retiarii/graph.py
View file @
9a68cdb2
...
@@ -603,16 +603,18 @@ class Node:
...
@@ -603,16 +603,18 @@ class Node:
@
staticmethod
@
staticmethod
def
_load
(
graph
:
Graph
,
name
:
str
,
ir
:
Any
)
->
'Node'
:
def
_load
(
graph
:
Graph
,
name
:
str
,
ir
:
Any
)
->
'Node'
:
if
ir
[
'operation'
][
'type'
]
==
'_cell'
:
if
ir
[
'operation'
][
'type'
]
==
'_cell'
:
op
=
Cell
(
ir
[
'operation'
][
'cell_name'
],
ir
[
'operation'
].
get
(
'parameters'
,
{}))
op
=
Cell
(
ir
[
'operation'
][
'cell_name'
],
ir
[
'operation'
].
get
(
'parameters'
,
{}),
attributes
=
ir
[
'operation'
].
get
(
'attributes'
,
{}))
else
:
else
:
op
=
Operation
.
new
(
ir
[
'operation'
][
'type'
],
ir
[
'operation'
].
get
(
'parameters'
,
{}))
op
=
Operation
.
new
(
ir
[
'operation'
][
'type'
],
ir
[
'operation'
].
get
(
'parameters'
,
{}),
attributes
=
ir
[
'operation'
].
get
(
'attributes'
,
{}))
node
=
Node
(
graph
,
uid
(),
name
,
op
)
node
=
Node
(
graph
,
uid
(),
name
,
op
)
if
'label'
in
ir
:
if
'label'
in
ir
:
node
.
update_label
(
ir
[
'label'
])
node
.
update_label
(
ir
[
'label'
])
return
node
return
node
def
_dump
(
self
)
->
Any
:
def
_dump
(
self
)
->
Any
:
ret
=
{
'operation'
:
{
'type'
:
self
.
operation
.
type
,
'parameters'
:
self
.
operation
.
parameters
}}
ret
=
{
'operation'
:
{
'type'
:
self
.
operation
.
type
,
'parameters'
:
self
.
operation
.
parameters
,
'attributes'
:
self
.
operation
.
attributes
}}
if
isinstance
(
self
.
operation
,
Cell
):
if
isinstance
(
self
.
operation
,
Cell
):
ret
[
'operation'
][
'cell_name'
]
=
self
.
operation
.
cell_name
ret
[
'operation'
][
'cell_name'
]
=
self
.
operation
.
cell_name
if
self
.
label
is
not
None
:
if
self
.
label
is
not
None
:
...
...
nni/retiarii/operation.py
View file @
9a68cdb2
...
@@ -34,10 +34,11 @@ class Operation:
...
@@ -34,10 +34,11 @@ class Operation:
Arbitrary key-value parameters (e.g. kernel_size).
Arbitrary key-value parameters (e.g. kernel_size).
"""
"""
def
__init__
(
self
,
type_name
:
str
,
parameters
:
Dict
[
str
,
Any
],
_internal
:
bool
=
False
):
def
__init__
(
self
,
type_name
:
str
,
parameters
:
Dict
[
str
,
Any
]
=
{}
,
_internal
:
bool
=
False
,
attributes
:
Dict
[
str
,
Any
]
=
{}
):
assert
_internal
,
'`Operation()` is private, use `Operation.new()` instead'
assert
_internal
,
'`Operation()` is private, use `Operation.new()` instead'
self
.
type
:
str
=
type_name
self
.
type
:
str
=
type_name
self
.
parameters
:
Dict
[
str
,
Any
]
=
parameters
self
.
parameters
:
Dict
[
str
,
Any
]
=
parameters
self
.
attributes
:
Dict
[
str
,
Any
]
=
attributes
def
to_init_code
(
self
,
field
:
str
)
->
str
:
def
to_init_code
(
self
,
field
:
str
)
->
str
:
raise
NotImplementedError
()
raise
NotImplementedError
()
...
@@ -52,9 +53,10 @@ class Operation:
...
@@ -52,9 +53,10 @@ class Operation:
return
True
return
True
@
staticmethod
@
staticmethod
def
new
(
type_name
:
str
,
parameters
:
Dict
[
str
,
Any
]
=
None
,
cell_name
:
str
=
None
)
->
'Operation'
:
def
new
(
type_name
:
str
,
parameters
:
Dict
[
str
,
Any
]
=
None
,
cell_name
:
str
=
None
,
if
parameters
is
None
:
attributes
:
Dict
[
str
,
Any
]
=
None
)
->
'Operation'
:
parameters
=
{}
parameters
=
parameters
or
{}
attributes
=
attributes
or
{}
if
type_name
==
'_cell'
:
if
type_name
==
'_cell'
:
# NOTE: cell_name is the same as its Node's name, when the cell is wrapped within the node
# NOTE: cell_name is the same as its Node's name, when the cell is wrapped within the node
return
Cell
(
cell_name
,
parameters
)
return
Cell
(
cell_name
,
parameters
)
...
@@ -67,7 +69,7 @@ class Operation:
...
@@ -67,7 +69,7 @@ class Operation:
cls
=
TensorFlowOperation
.
_find_subclass
(
type_name
)
cls
=
TensorFlowOperation
.
_find_subclass
(
type_name
)
else
:
else
:
raise
ValueError
(
f
'Unsupported framework:
{
debug_configs
.
framework
}
'
)
raise
ValueError
(
f
'Unsupported framework:
{
debug_configs
.
framework
}
'
)
return
cls
(
type_name
,
parameters
,
_internal
=
True
)
return
cls
(
type_name
,
parameters
,
_internal
=
True
,
attributes
=
attributes
)
@
classmethod
@
classmethod
def
_find_subclass
(
cls
,
subclass_name
):
def
_find_subclass
(
cls
,
subclass_name
):
...
@@ -205,12 +207,11 @@ class Cell(PyTorchOperation):
...
@@ -205,12 +207,11 @@ class Cell(PyTorchOperation):
No real usage. Exists for compatibility with base class.
No real usage. Exists for compatibility with base class.
"""
"""
def
__init__
(
self
,
cell_name
:
str
,
parameters
:
Dict
[
str
,
Any
]
=
None
):
def
__init__
(
self
,
cell_name
:
str
,
parameters
:
Dict
[
str
,
Any
]
=
None
,
attributes
:
Dict
[
str
,
Any
]
=
None
):
self
.
type
=
'_cell'
self
.
type
=
'_cell'
self
.
cell_name
=
cell_name
self
.
cell_name
=
cell_name
if
parameters
is
None
:
self
.
parameters
=
parameters
or
{}
parameters
=
{}
self
.
attributes
=
attributes
or
{}
self
.
parameters
=
parameters
def
_to_class_name
(
self
):
def
_to_class_name
(
self
):
# TODO: ugly, think about how to refactor this part
# TODO: ugly, think about how to refactor this part
...
...
nni/retiarii/operation_def/tf_op_def.py
View file @
9a68cdb2
...
@@ -5,7 +5,7 @@ from ..operation import TensorFlowOperation
...
@@ -5,7 +5,7 @@ from ..operation import TensorFlowOperation
class
Conv2D
(
TensorFlowOperation
):
class
Conv2D
(
TensorFlowOperation
):
def
__init__
(
self
,
type_name
,
parameters
,
_internal
):
def
__init__
(
self
,
type_name
,
parameters
,
_internal
,
attributes
=
None
):
if
'padding'
not
in
parameters
:
if
'padding'
not
in
parameters
:
parameters
[
'padding'
]
=
'same'
parameters
[
'padding'
]
=
'same'
super
().
__init__
(
type_name
,
parameters
,
_internal
)
super
().
__init__
(
type_name
,
parameters
,
_internal
)
test/ut/retiarii/mnist-tensorflow.json
View file @
9a68cdb2
...
@@ -4,11 +4,11 @@
...
@@ -4,11 +4,11 @@
"outputs"
:
[
"metric"
],
"outputs"
:
[
"metric"
],
"nodes"
:
{
"nodes"
:
{
"stem"
:
{
"operation"
:
{
"type"
:
"_cell"
,
"parameters"
:
{},
"cell_name"
:
"stem"
}},
"stem"
:
{
"operation"
:
{
"type"
:
"_cell"
,
"parameters"
:
{},
"attributes"
:
{},
"cell_name"
:
"stem"
}},
"flatten"
:
{
"operation"
:
{
"type"
:
"Flatten"
,
"parameters"
:
{}}},
"flatten"
:
{
"operation"
:
{
"type"
:
"Flatten"
,
"parameters"
:
{},
"attributes"
:
{}}},
"fc1"
:
{
"operation"
:
{
"type"
:
"Dense"
,
"parameters"
:
{
"units"
:
1024
,
"activation"
:
"relu"
}}},
"fc1"
:
{
"operation"
:
{
"type"
:
"Dense"
,
"parameters"
:
{
"units"
:
1024
,
"activation"
:
"relu"
}
,
"attributes"
:
{}
}},
"fc2"
:
{
"operation"
:
{
"type"
:
"Dense"
,
"parameters"
:
{
"units"
:
10
}}},
"fc2"
:
{
"operation"
:
{
"type"
:
"Dense"
,
"parameters"
:
{
"units"
:
10
}
,
"attributes"
:
{}
}},
"softmax"
:
{
"operation"
:
{
"type"
:
"Softmax"
,
"parameters"
:
{}}}
"softmax"
:
{
"operation"
:
{
"type"
:
"Softmax"
,
"parameters"
:
{},
"attributes"
:
{}}}
},
},
"edges"
:
[
"edges"
:
[
...
@@ -23,10 +23,10 @@
...
@@ -23,10 +23,10 @@
"stem"
:
{
"stem"
:
{
"nodes"
:
{
"nodes"
:
{
"conv1"
:
{
"operation"
:
{
"type"
:
"Conv2D"
,
"parameters"
:
{
"filters"
:
32
,
"kernel_size"
:
5
,
"activation"
:
"relu"
}}},
"conv1"
:
{
"operation"
:
{
"type"
:
"Conv2D"
,
"parameters"
:
{
"filters"
:
32
,
"kernel_size"
:
5
,
"activation"
:
"relu"
}
,
"attributes"
:
{}
}},
"pool1"
:
{
"operation"
:
{
"type"
:
"MaxPool2D"
,
"parameters"
:
{
"pool_size"
:
2
}}},
"pool1"
:
{
"operation"
:
{
"type"
:
"MaxPool2D"
,
"parameters"
:
{
"pool_size"
:
2
}
,
"attributes"
:
{}
}},
"conv2"
:
{
"operation"
:
{
"type"
:
"Conv2D"
,
"parameters"
:
{
"filters"
:
64
,
"kernel_size"
:
5
,
"activation"
:
"relu"
}}},
"conv2"
:
{
"operation"
:
{
"type"
:
"Conv2D"
,
"parameters"
:
{
"filters"
:
64
,
"kernel_size"
:
5
,
"activation"
:
"relu"
}
,
"attributes"
:
{}
}},
"pool2"
:
{
"operation"
:
{
"type"
:
"MaxPool2D"
,
"parameters"
:
{
"pool_size"
:
2
}}}
"pool2"
:
{
"operation"
:
{
"type"
:
"MaxPool2D"
,
"parameters"
:
{
"pool_size"
:
2
}
,
"attributes"
:
{}
}}
},
},
"edges"
:
[
"edges"
:
[
...
...
test/ut/retiarii/test_convert_shape.py
View file @
9a68cdb2
...
@@ -24,12 +24,12 @@ class TestShape(unittest.TestCase, ConvertWithShapeMixin):
...
@@ -24,12 +24,12 @@ class TestShape(unittest.TestCase, ConvertWithShapeMixin):
conv_node
=
model_ir
.
get_nodes_by_type
(
'__torch__.torch.nn.modules.conv.Conv2d'
)[
0
]
conv_node
=
model_ir
.
get_nodes_by_type
(
'__torch__.torch.nn.modules.conv.Conv2d'
)[
0
]
relu_node
=
model_ir
.
get_nodes_by_type
(
'__torch__.torch.nn.modules.activation.ReLU'
)[
0
]
relu_node
=
model_ir
.
get_nodes_by_type
(
'__torch__.torch.nn.modules.activation.ReLU'
)[
0
]
pool_node
=
model_ir
.
get_nodes_by_type
(
'__torch__.torch.nn.modules.pooling.MaxPool2d'
)[
0
]
pool_node
=
model_ir
.
get_nodes_by_type
(
'__torch__.torch.nn.modules.pooling.MaxPool2d'
)[
0
]
self
.
assertEqual
(
conv_node
.
operation
.
parame
te
r
s
.
get
(
'input_shape'
),
[[
1
,
3
,
224
,
224
]])
self
.
assertEqual
(
conv_node
.
operation
.
attribu
tes
.
get
(
'input_shape'
),
[[
1
,
3
,
224
,
224
]])
self
.
assertEqual
(
conv_node
.
operation
.
parame
te
r
s
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
conv_node
.
operation
.
attribu
tes
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
relu_node
.
operation
.
parame
te
r
s
.
get
(
'input_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
relu_node
.
operation
.
attribu
tes
.
get
(
'input_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
relu_node
.
operation
.
parame
te
r
s
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
relu_node
.
operation
.
attribu
tes
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
pool_node
.
operation
.
parame
te
r
s
.
get
(
'input_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
pool_node
.
operation
.
attribu
tes
.
get
(
'input_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
pool_node
.
operation
.
parame
te
r
s
.
get
(
'output_shape'
),
[[
1
,
1
,
111
,
111
]])
self
.
assertEqual
(
pool_node
.
operation
.
attribu
tes
.
get
(
'output_shape'
),
[[
1
,
1
,
111
,
111
]])
def
test_nested_module
(
self
):
def
test_nested_module
(
self
):
class
ConvRelu
(
nn
.
Module
):
class
ConvRelu
(
nn
.
Module
):
...
@@ -54,8 +54,8 @@ class TestShape(unittest.TestCase, ConvertWithShapeMixin):
...
@@ -54,8 +54,8 @@ class TestShape(unittest.TestCase, ConvertWithShapeMixin):
# check if shape propagation works
# check if shape propagation works
cell_node
=
model_ir
.
get_nodes_by_type
(
'_cell'
)[
0
]
cell_node
=
model_ir
.
get_nodes_by_type
(
'_cell'
)[
0
]
self
.
assertEqual
(
cell_node
.
operation
.
parame
te
r
s
.
get
(
'input_shape'
),
[[
1
,
3
,
224
,
224
]])
self
.
assertEqual
(
cell_node
.
operation
.
attribu
tes
.
get
(
'input_shape'
),
[[
1
,
3
,
224
,
224
]])
self
.
assertEqual
(
cell_node
.
operation
.
parame
te
r
s
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
cell_node
.
operation
.
attribu
tes
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
def
test_layerchoice
(
self
):
def
test_layerchoice
(
self
):
class
ConvNet
(
nn
.
Module
):
class
ConvNet
(
nn
.
Module
):
...
@@ -75,5 +75,5 @@ class TestShape(unittest.TestCase, ConvertWithShapeMixin):
...
@@ -75,5 +75,5 @@ class TestShape(unittest.TestCase, ConvertWithShapeMixin):
# check shape info of each candidates
# check shape info of each candidates
conv_nodes
=
model_ir
.
get_nodes_by_type
(
'__torch__.torch.nn.modules.conv.Conv2d'
)
conv_nodes
=
model_ir
.
get_nodes_by_type
(
'__torch__.torch.nn.modules.conv.Conv2d'
)
self
.
assertEqual
(
conv_nodes
[
0
].
operation
.
parame
te
r
s
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
conv_nodes
[
0
].
operation
.
attribu
tes
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
conv_nodes
[
1
].
operation
.
parame
te
r
s
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
self
.
assertEqual
(
conv_nodes
[
1
].
operation
.
attribu
tes
.
get
(
'output_shape'
),
[[
1
,
1
,
222
,
222
]])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment