OpenDAS / nni · Commits · 58d5c2fa

Unverified commit 58d5c2fa, authored Feb 22, 2021 by QuanluZhang, committed by GitHub on Feb 22, 2021

[retiarii] refactor of pytorch operators (#3365)

Parent: 59521d33

Showing 15 changed files with 4020 additions and 310 deletions (+4020 −310)
nni/retiarii/codegen/pytorch.py              +35    −6
nni/retiarii/converter/graph_gen.py          +244   −120
nni/retiarii/converter/op_types.py           +0     −26
nni/retiarii/nn/pytorch/mutator.py           +1     −1
nni/retiarii/operation.py                    +47    −52
nni/retiarii/operation_def/torch_op_def.py   +407   −29
nni/retiarii/utils.py                        +8     −0
pipelines/fast-test.yml                      +3     −1
test/ut/retiarii/inject_nn.py                +280   −0
test/ut/retiarii/test_convert.py             +87    −75
test/ut/retiarii/test_convert_basic.py       +283   −0
test/ut/retiarii/test_convert_models.py      +0     −0
test/ut/retiarii/test_convert_operators.py   +1390  −0
test/ut/retiarii/test_convert_pytorch.py     +1234  −0
test/ut/retiarii/test_highlevel_apis.py      +1     −0
nni/retiarii/codegen/pytorch.py

 import logging
-from typing import List
+from typing import List, Tuple, Any

 from ..graph import IllegalGraphError, Edge, Graph, Node, Model
...
@@ -32,9 +32,26 @@ def _sorted_incoming_edges(node: Node) -> List[Edge]:
     raise IllegalGraphError(node.graph, 'Node {} has bad inputs'.format(node.name))

-def _format_inputs(node: Node) -> List[str]:
+def _format_inputs(node: Node) -> Tuple[List[str], List[Any]]:
+    """
+    Format the inputs of a given node
+
+    Parameters
+    ----------
+    node : Node
+        a graph node, get and format its inputs
+
+    Returns
+    -------
+    list
+        the list of input names
+    list
+        the list of input values, if an input is simple type, record its value,
+        otherwise the value is None
+    """
     edges = _sorted_incoming_edges(node)
     inputs = []
+    inputs_value = []
     for edge in edges:
         if edge.head.name == '_inputs':
             assert isinstance(edge.head_slot, int)
...
@@ -44,14 +61,21 @@ def _format_inputs(node: Node) -> List[str]:
             else:
                 # when input has no name, e.g., forward(*_inputs)
                 inputs.append('_inputs[{}]'.format(edge.head_slot))
+                inputs_value.append(None)
         else:
             if edge.head_slot is None:
                 # when the input comes from a single-output operator
                 inputs.append('{}'.format(edge.head.name))
+                if edge.head.operation.type in ('prim::Constant', 'prim::GetAttr') and \
+                        'value' in edge.head.operation.parameters:
+                    inputs_value.append(edge.head.operation.parameters['value'])
+                else:
+                    inputs_value.append(None)
             else:
                 # when the input comes from a multi-output operator: needs to know which one it comes from
                 inputs.append('{}[{}]'.format(edge.head.name, edge.head_slot))
-    return inputs
+                inputs_value.append(None)
+    return inputs, inputs_value

 def _remove_prefix(names, graph_name):
...
@@ -80,6 +104,8 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
     node_codes = []
     for node in nodes:
         if node.operation:
+            if node.operation.type == 'shared':
+                continue
             pkg_name = node.operation.get_import_pkg()
             if pkg_name is not None:
                 import_pkgs.add(pkg_name)
...
@@ -101,12 +127,15 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
     sorted_nodes = graph.topo_sort()
     for node in sorted_nodes:
         if node.operation:
-            inputs = _format_inputs(node)
+            inputs, inputs_value = _format_inputs(node)
             inputs = _remove_prefix(inputs, graph_name)
             node_name = _remove_prefix(node.name, graph_name)
-            edge_codes.append(node.operation.to_forward_code(node_name, node_name, inputs))
+            submodule_name = node_name
+            if node.operation.type == 'shared':
+                submodule_name = _remove_prefix(node.operation.parameters['reference'], graph_name)
+            edge_codes.append(node.operation.to_forward_code(submodule_name, node_name, inputs, inputs_value))

-    output_names = _format_inputs(graph.output_node)
+    output_names, _ = _format_inputs(graph.output_node)
     output_names = _remove_prefix(output_names, graph_name)
     if not output_names:
         raise RuntimeError('"forward" function should have return value(s): {}, {}, {}'.format(output_names, graph_name, graph.output_node))
...
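Note: with _format_inputs now returning both the textual input names and, where statically known, their constant values, an operator's to_forward_code can specialize the emitted line on those constants. The snippet below is a minimal, self-contained sketch of that idea rather than the code in this commit; format_line and its operation handling are hypothetical and only illustrate how inputs and inputs_value travel together.

from typing import Any, List, Optional

def format_line(op_type: str, output: str, inputs: List[str],
                inputs_value: List[Optional[Any]]) -> str:
    # Hypothetical emitter: if the second argument of aten::contiguous is a
    # known constant (memory_format index), bake it into the generated line.
    if op_type == 'aten::contiguous' and inputs_value[1] is not None:
        mem_format = ['torch.contiguous_format', 'torch.preserve_format', 'torch.channels_last']
        return f'{output} = {inputs[0]}.contiguous(memory_format={mem_format[inputs_value[1]]})'
    # Fallback: treat every input as a runtime variable.
    return f'{output} = {op_type.split("::")[-1]}({", ".join(inputs)})'

print(format_line('aten::contiguous', 'out', ['x', 'c_1'], [None, 0]))
# out = x.contiguous(memory_format=torch.contiguous_format)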
nni/retiarii/converter/graph_gen.py
...
@@ -5,9 +5,9 @@ import torch

 from ..graph import Graph, Model, Node
 from ..nn.pytorch import InputChoice, LayerChoice, Placeholder
-from ..operation import Cell
+from ..operation import Cell, Operation
 from ..utils import get_records
-from .op_types import MODULE_EXCEPT_LIST, BasicOpsPT, OpTypeName
+from .op_types import MODULE_EXCEPT_LIST, OpTypeName
 from .utils import _convert_name, build_full_name

 _logger = logging.getLogger(__name__)
...
@@ -19,6 +19,33 @@ class GraphConverter:
         self.global_graph_id = 0
         self.modules_arg = get_records()

+    def _add_edge_handle_source_node(self, _input, graph_inputs, ir_graph, output_remap, node_index):
+        if _input in graph_inputs:
+            idx = graph_inputs.index(_input)
+            src_node = ir_graph.input_node
+            src_node_idx = idx
+        elif _input in output_remap:
+            assert output_remap[_input].kind() == 'aten::append'
+            predecessor_node = output_remap[_input]
+            assert predecessor_node in node_index, 'predecessor node: {}'.format(predecessor_node)
+            src_node_idx = None
+            src_node = node_index[predecessor_node]
+            assert isinstance(src_node, Node)
+        else:
+            predecessor_node = _input.node()
+            assert predecessor_node in node_index, 'predecessor node: {}'.format(predecessor_node)
+            # find out the index of _input in the outputs of predecessor_node
+            predecessor_outputs = [_output for _output in predecessor_node.outputs()]
+            if len(predecessor_outputs) == 1:
+                idx = None
+            else:
+                idx = predecessor_outputs.index(_input)
+            ir_predecessor_node = node_index[predecessor_node]
+            src_node_idx = idx
+            assert isinstance(ir_predecessor_node, Node)
+            src_node = ir_predecessor_node
+        return src_node, src_node_idx
+
     def _add_edge(self, ir_graph, node, graph_inputs, node_index, new_node, output_remap, ignore_first=False):
         """
         Parameters
...
@@ -40,57 +67,40 @@ class GraphConverter:
             if ignore_first:
                 ignore_first = False
                 continue
             # handle source node
-            if _input in graph_inputs:
-                idx = graph_inputs.index(_input)
-                src_node = ir_graph.input_node
-                src_node_idx = idx
-            elif _input in output_remap:
-                assert output_remap[_input].kind() == 'aten::append'
-                predecessor_node = output_remap[_input]
-                assert predecessor_node in node_index, 'predecessor node: {}'.format(predecessor_node)
-                src_node_idx = None
-                src_node = node_index[predecessor_node]
-                assert isinstance(src_node, Node)
-            else:
-                predecessor_node = _input.node()
-                assert predecessor_node in node_index, 'predecessor node: {}'.format(predecessor_node)
-                # find out the index of _input in the outputs of predecessor_node
-                predecessor_outputs = [_output for _output in predecessor_node.outputs()]
-                if len(predecessor_outputs) == 1:
-                    idx = None
-                else:
-                    idx = predecessor_outputs.index(_input)
-                ir_predecessor_node = node_index[predecessor_node]
-                src_node_idx = idx
-                assert isinstance(ir_predecessor_node, Node)
-                src_node = ir_predecessor_node
+            src_node, src_node_idx = self._add_edge_handle_source_node(_input, graph_inputs, ir_graph, output_remap, node_index)
             # handle destination node
             dst_node = new_node
             if is_single_input:
                 dst_node_idx = None
             else:
                 dst_node_idx = new_node_input_idx
             # create edge
             ir_graph.add_edge(head=(src_node, src_node_idx), tail=(dst_node, dst_node_idx))

             new_node_input_idx += 1

     def create_prim_constant_node(self, ir_graph, node, module_name):
-        attrs = {}
-        if node.outputsAt(0).toIValue() is not None:
-            attrs = {'value': node.outputsAt(0).toIValue()}
+        # NOTE: compare with string not type, because the type is defined in pytorch C code.
+        # `.kind()` can also be used here
+        if node.outputsAt(0).type().str() == 'None':
+            attrs = {'type': 'None'}
+        else:
+            attrs = {'type': node.outputsAt(0).type().str(), 'value': node.outputsAt(0).toIValue()}
         self.global_seq += 1
         new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Constant, self.global_seq),
                                      node.kind(), attrs)
         return new_node

-    def handle_prim_attr_node(self, node):
+    def handle_prim_attr_node(self, node, module):
         assert node.hasAttribute('name')
-        attrs = {'name': node.s('name'), 'input': node.inputsAt(0).debugName()}
+        value = None
+        if node.inputsAt(0).debugName() == 'self':
+            _val = getattr(module, node.s('name'))
+            # TODO: serialize complex data type, and output proper error message
+            if isinstance(_val, (int, float, str, bool)):
+                value = _val
+        attrs = {'name': node.s('name'), 'input': node.inputsAt(0).debugName(), 'value': value}
         return node.kind(), attrs

     def _remove_mangle(self, module_type_str):
...
@@ -124,7 +134,10 @@ class GraphConverter:
         for hidden_node in to_removes:
             hidden_node.remove()

-    def handle_graph_nodes(self, script_module, sm_graph, module, module_name, ir_model, ir_graph):
+    def handle_graph_nodes(self, script_module, sm_graph, module, module_name, ir_model, ir_graph,
+                           shared_module_index=None):
         """
         Convert torch script node to our node ir, and build our graph ir
...
@@ -142,6 +155,10 @@ class GraphConverter:
             the whole graph ir
         ir_graph : Graph
             the graph ir of ```module```
+        shared_module_index : dict
+            it is used for knowing which module has been created an ir node,
+            if created and invoked again, then the new ir node can simply reference that ir node.
+            this way we can identify shared modules (i.e., one module invoked multiple times in `forward` function)

         Returns
         -------
...
@@ -159,6 +176,8 @@ class GraphConverter:
             ir_graph._add_input(_convert_name(_input.debugName()))

         node_index = {}  # graph node to graph ir node
+        if shared_module_index is None:
+            shared_module_index = {}

         # some node does not have output but it modifies a variable, for example aten::append
         # %17 : Tensor[] = aten::append(%out.1, %16)
...
@@ -167,6 +186,7 @@ class GraphConverter:
         # key: tensor (%out.1), value: node (this node)
         output_remap = {}

+        # ===================handle control flow: if===================
         def handle_if_condition(cond_tensor):
             """
             to calculate the condition, we only deal with the following op types by tracing back
...
@@ -189,8 +209,45 @@ class GraphConverter:
                     left = _generate_expr(tensor.node().inputsAt(0))
                     right = _generate_expr(tensor.node().inputsAt(1))
                     return f'({left} == {right})'
+                elif tensor.node().kind() == 'aten::le':
+                    left = _generate_expr(tensor.node().inputsAt(0))
+                    right = _generate_expr(tensor.node().inputsAt(1))
+                    return f'({left} <= {right})'
+                elif tensor.node().kind() == 'aten::ge':
+                    left = _generate_expr(tensor.node().inputsAt(0))
+                    right = _generate_expr(tensor.node().inputsAt(1))
+                    return f'({left} >= {right})'
+                elif tensor.node().kind() == 'aten::__not__':
+                    value = _generate_expr(tensor.node().inputsAt(0))
+                    return f'(not {value})'
+                elif tensor.node().kind() == 'aten::Bool':
+                    value = _generate_expr(tensor.node().inputsAt(0))
+                    return f'bool({value})'
+                elif tensor.node().kind() == 'aten::__is__':
+                    left = _generate_expr(tensor.node().inputsAt(0))
+                    right = _generate_expr(tensor.node().inputsAt(1))
+                    return f'({left} is {right})'
+                elif tensor.node().kind() == 'aten::__isnot__':
+                    left = _generate_expr(tensor.node().inputsAt(0))
+                    right = _generate_expr(tensor.node().inputsAt(1))
+                    return f'({left} is not {right})'
+                elif tensor.node().kind() == 'aten::ne':
+                    left = _generate_expr(tensor.node().inputsAt(0))
+                    right = _generate_expr(tensor.node().inputsAt(1))
+                    return f'({left} != {right})'
+                elif tensor.node().kind() == 'aten::gt':
+                    left = _generate_expr(tensor.node().inputsAt(0))
+                    right = _generate_expr(tensor.node().inputsAt(1))
+                    return f'({left} > {right})'
+                elif tensor.node().kind() == 'aten::lt':
+                    left = _generate_expr(tensor.node().inputsAt(0))
+                    right = _generate_expr(tensor.node().inputsAt(1))
+                    return f'({left} < {right})'
+                elif tensor.node().kind() == 'prim::If':
+                    raise RuntimeError('Have not supported `if A and/or B`, please use two `if` statements instead.')
                 else:
-                    raise RuntimeError(f'Unsupported op type {tensor.node().kind()} in if condition')
+                    raise RuntimeError(f'Unsupported op type {tensor.node().kind()} in if condition, '
+                                       'you are suggested to decorate the corresponding class with "@blackbox_module".')
             expr = _generate_expr(cond_tensor)
             return eval(expr)
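Note: handle_if_condition turns the TorchScript nodes feeding a prim::If into a Python expression string and evaluates it with eval to choose a branch at conversion time. A stripped-down sketch of the same idea, with a hypothetical generate_expr standing in for the nested _generate_expr:

# Hypothetical, self-contained illustration of the branch-selection idea:
# build a textual boolean expression from already-known constant operands,
# then evaluate it to decide which block of an `if` to convert.
def generate_expr(op_kind: str, left: str, right: str) -> str:
    ops = {'aten::eq': '==', 'aten::ne': '!=', 'aten::le': '<=',
           'aten::ge': '>=', 'aten::lt': '<', 'aten::gt': '>'}
    return f'({left} {ops[op_kind]} {right})'

expr = generate_expr('aten::gt', '3', '2')   # "(3 > 2)"
chosen_block = 0 if eval(expr) else 1        # pick the "then" block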
...
@@ -217,8 +274,128 @@ class GraphConverter:
             last_block_node = None
             for node in blocks[chosen_block].nodes():
                 last_block_node = handle_single_node(node)
+            self.global_seq += 1
+            new_node = ir_graph.add_node(build_full_name(module_name, 'noop_identity', self.global_seq), 'noop_identity')
+            self._add_edge(ir_graph, blocks[chosen_block].returnNode(), graph_inputs, node_index, new_node, output_remap)
+            last_block_node = new_node
             return last_block_node

+        # ===================handle function call===================
+        def handle_function_callmethod(node):
+            # get and handle the first input, which should be an nn.Module
+            assert node.hasAttribute('name')
+            # NOTE: "forward__0" is hacky, LSTM instance is parsed to call forward__0 in torchscript
+            if node.s('name') in ['forward', 'forward__0']:
+                # node.inputsAt(0).type() is <class 'torch._C.ClassType'>
+                submodule_type_str = self._remove_mangle(node.inputsAt(0).type().str())
+                submodule = node.inputsAt(0).node()
+                assert submodule.kind() == 'prim::GetAttr'
+                assert submodule.hasAttribute('name')
+                submodule_name = submodule.s('name')
+
+                if submodule.inputsAt(0).debugName() == 'self':
+                    # module is usually instantiated in __init__.
+                    # when calling a module in forward,
+                    # prim::GetAttr is used to obtain the module in torch script.
+                    # therefore, we do this check for a module. example below:
+                    # %25 : __torch__.xxx = prim::GetAttr[name="input_switch"](%self)
+                    # %27 : Tensor = prim::CallMethod[name="forward"](%25, %out.1)
+                    assert submodule_name in script_module._modules, \
+                        "submodule_name: {} not in script_module {}".format(submodule_name, script_module._modules.keys())
+
+                    submodule_full_name = build_full_name(module_name, submodule_name)
+                    submodule_obj = getattr(module, submodule_name)
+                    subgraph, sub_m_attrs = self.convert_module(script_module._modules[submodule_name],
+                                                                submodule_obj, submodule_full_name, ir_model)
+                else:
+                    # %8 : __torch__.nni.retiarii.model_apis.nn.___torch_mangle_37.ModuleList = prim::GetAttr[name="cells"](%self)
+                    # %10 : __torch__.darts_model.Cell = prim::GetAttr[name="0"](%8)
+                    # %s1.4 : Tensor = prim::CallMethod[name="forward"](%10, %4, %4)
+                    if submodule.inputsAt(0).type().name() == 'ModuleList':
+                        # handle ModuleList
+                        predecessor = submodule.inputsAt(0).node()
+                        assert predecessor.kind() == 'prim::GetAttr'
+                        assert predecessor.hasAttribute('name')
+                        assert predecessor.inputsAt(0).debugName() == 'self'
+                        predecessor_name = predecessor.s('name')
+                        # TODO: exchange submodule_name and predecessor_name
+                        submodule_full_name = build_full_name(module_name, [submodule_name, predecessor_name])
+                        predecessor_obj = getattr(module, predecessor_name)
+                        submodule_obj = getattr(predecessor_obj, submodule_name)
+                        subgraph, sub_m_attrs = self.convert_module(script_module._modules[predecessor_name]._modules[submodule_name],
+                                                                    submodule_obj, submodule_full_name, ir_model)
+                    else:
+                        raise RuntimeError('Unsupported module case: {}'.format(submodule.inputsAt(0).type().str()))
+
+                if submodule_full_name in shared_module_index:
+                    # this module is invoked more than once, the ir node has already been created
+                    # create a reference node for it.
+                    # example: {"name": "conv2", "operation": {"type": "shared", "parameters": {"reference": "conv1"}}}
+                    self.global_seq += 1
+                    shared_node_name = build_full_name(submodule_full_name, '', self.global_seq)
+                    shared_type_operation = Operation.new('shared', {'reference': submodule_full_name})
+                    subcell = ir_graph.add_node(shared_node_name, shared_type_operation)
+                else:
+                    # this module is processed for the first time, build cell for it
+                    if subgraph is None:
+                        # if we do not parse this module's graph, we create Node for this module
+                        subcell = ir_graph.add_node(submodule_full_name, submodule_type_str, sub_m_attrs)
+                        if isinstance(submodule_obj, Placeholder):
+                            subcell.update_label(submodule_obj.label)
+                        elif isinstance(submodule_obj, (LayerChoice, InputChoice)):
+                            subcell.update_label(sub_m_attrs['label'])
+                    else:
+                        # Graph already created, create Cell for it
+                        new_cell = Cell(cell_name=submodule_full_name, parameters=sub_m_attrs)
+                        subcell = ir_graph.add_node(submodule_full_name, new_cell)
+                    shared_module_index[submodule_full_name] = subcell
+                node_index[node] = subcell
+                # connect the cell into graph
+                self._add_edge(ir_graph, node, graph_inputs, node_index, subcell, output_remap, ignore_first=True)
+            else:
+                # handle normal member function
+                assert hasattr(script_module, node.s('name'))
+                # TODO: support non member functions
+                assert node.inputsAt(0).debugName() == 'self'
+                script_method = getattr(script_module, node.s('name'))  # <class 'torch._C.ScriptMethod'>
+
+                # step #1: generate graph ir for this method
+                method_ir_graph = Graph(model=ir_model, graph_id=-100, name='temp_graph', _internal=True)
+                method_node_index = self.handle_graph_nodes(script_module, script_method.graph, module,
+                                                            module_name, ir_model, method_ir_graph, shared_module_index)
+                for _output in script_method.graph.outputs():
+                    method_ir_graph._add_output(_convert_name(_output.debugName()))
+                    predecessor_node_outputs = [o for o in _output.node().outputs()]
+                    if len(predecessor_node_outputs) == 1:
+                        src_node_idx = None
+                    else:
+                        src_node_idx = predecessor_node_outputs.index(_output)
+                    method_ir_graph.add_edge(head=(method_node_index[_output.node()], src_node_idx),
+                                             tail=(method_ir_graph.output_node, None))
+                self.refine_graph(method_ir_graph)
+
+                # step #2: merge this graph to its module graph
+                for h_node in method_ir_graph.hidden_nodes:
+                    h_node.graph = ir_graph
+                    ir_graph.hidden_nodes.append(h_node)
+                for edge in method_ir_graph.edges:
+                    edge.graph = ir_graph
+                    if edge.head == method_ir_graph.input_node:
+                        # this is a member method, 'self' is the first argument, thus +1
+                        _input = node.inputsAt(edge.head_slot + 1)
+                        src_node, src_node_idx = self._add_edge_handle_source_node(_input, graph_inputs, ir_graph, output_remap, node_index)
+                        edge.head = src_node
+                        edge.head_slot = src_node_idx
+                    if edge.tail == method_ir_graph.output_node:
+                        # since the following nodes have not been created, skip this edge
+                        # edge.head is the output node of this method
+                        # TODO: check whether there could be multiple output nodes???
+                        node_index[node] = edge.head
+                        continue
+                    ir_graph.edges.append(edge)
+
+        # ===================handle each single node===================
         def handle_single_node(node):
             """
             Parameters
...
@@ -232,69 +409,7 @@ class GraphConverter:
                 the created node ir
             """
             if node.kind() == 'prim::CallMethod':
+                # get and handle the first input, which should be an nn.Module
+                handle_function_callmethod(node)
-                assert node.hasAttribute('name')
-                if node.s('name') == 'forward':
-                    # node.inputsAt(0).type() is <class 'torch._C.ClassType'>
-                    submodule_type_str = self._remove_mangle(node.inputsAt(0).type().str())
-                    submodule = node.inputsAt(0).node()
-                    assert submodule.kind() == 'prim::GetAttr'
-                    assert submodule.hasAttribute('name')
-                    submodule_name = submodule.s('name')
-                    if submodule.inputsAt(0).debugName() == 'self':
-                        # module is usually instantiated in __init__.
-                        # when calling a module in forward,
-                        # prim::GetAttr is used to obtain the module in torch script.
-                        # therefore, we do this check for a module. example below:
-                        # %25 : __torch__.xxx = prim::GetAttr[name="input_switch"](%self)
-                        # %27 : Tensor = prim::CallMethod[name="forward"](%25, %out.1)
-                        assert submodule_name in script_module._modules, \
-                            "submodule_name: {} not in script_module {}".format(submodule_name, script_module._modules.keys())
-                        submodule_full_name = build_full_name(module_name, submodule_name)
-                        submodule_obj = getattr(module, submodule_name)
-                        subgraph, sub_m_attrs = self.convert_module(script_module._modules[submodule_name],
-                                                                    submodule_obj, submodule_full_name, ir_model)
-                    else:
-                        # %8 : __torch__.nni.retiarii.model_apis.nn.___torch_mangle_37.ModuleList = prim::GetAttr[name="cells"](%self)
-                        # %10 : __torch__.darts_model.Cell = prim::GetAttr[name="0"](%8)
-                        # %s1.4 : Tensor = prim::CallMethod[name="forward"](%10, %4, %4)
-                        if submodule.inputsAt(0).type().name() == 'ModuleList':
-                            # handle ModuleList
-                            predecessor = submodule.inputsAt(0).node()
-                            assert predecessor.kind() == 'prim::GetAttr'
-                            assert predecessor.hasAttribute('name')
-                            assert predecessor.inputsAt(0).debugName() == 'self'
-                            predecessor_name = predecessor.s('name')
-                            # FIXME: exchange
-                            submodule_full_name = build_full_name(module_name, [submodule_name, predecessor_name])
-                            predecessor_obj = getattr(module, predecessor_name)
-                            submodule_obj = getattr(predecessor_obj, submodule_name)
-                            subgraph, sub_m_attrs = self.convert_module(script_module._modules[predecessor_name]._modules[submodule_name],
-                                                                        submodule_obj, submodule_full_name, ir_model)
-                        else:
-                            raise RuntimeError('Unsupported module case: {}'.format(submodule.inputsAt(0).type().str()))
-                    # TODO: match subgraph with maintained graphs
-                    # build cell
-                    if subgraph is None:
-                        # if we do not parse this module's graph, we create Node for this module
-                        subcell = ir_graph.add_node(submodule_full_name, submodule_type_str, sub_m_attrs)
-                        if isinstance(submodule_obj, Placeholder):
-                            subcell.update_label(submodule_obj.label)
-                        elif isinstance(submodule_obj, (LayerChoice, InputChoice)):
-                            subcell.update_label(sub_m_attrs['label'])
-                    else:
-                        # Graph already created, create Cell for it
-                        new_cell = Cell(cell_name=submodule_full_name, parameters=sub_m_attrs)
-                        subcell = ir_graph.add_node(submodule_full_name, new_cell)
-                    node_index[node] = subcell
-                    # connect the cell into graph
-                    self._add_edge(ir_graph, node, graph_inputs, node_index, subcell, output_remap, ignore_first=True)
-                else:
-                    raise RuntimeError('unsupported CallMethod {}'.format(node.s('name')))
             elif node.kind() == 'prim::CallFunction':
                 func_type_str = self._remove_mangle(node.inputsAt(0).type().str())
                 func = node.inputsAt(0).node()
...
@@ -310,30 +425,14 @@ class GraphConverter:
             elif node.kind() == 'prim::Constant':
                 new_node = self.create_prim_constant_node(ir_graph, node, module_name)
                 node_index[node] = new_node
-            elif node.kind() == 'prim::ListConstruct':
+            elif node.kind() in ['prim::ListConstruct', 'prim::ListUnpack', 'prim::TupleConstruct', 'prim::TupleUnpack']:
                 self.global_seq += 1
-                new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.ListConstruct, self.global_seq), node.kind())
+                prim_op_name = node.kind().split('::')[-1]
+                new_node = ir_graph.add_node(build_full_name(module_name, prim_op_name, self.global_seq), node.kind())
                 node_index[node] = new_node
                 self._add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap)
-            elif node.kind() == 'prim::TupleConstruct':
-                self.global_seq += 1
-                new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.TupleConstruct, self.global_seq), node.kind())
-                node_index[node] = new_node
-                self._add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap)
-            elif node.kind() == 'aten::append':
-                self.global_seq += 1
-                aten_node = ir_graph.add_node(build_full_name(module_name, BasicOpsPT[node.kind()], self.global_seq), node.kind())
-                node_index[node] = aten_node
-                self._add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
-                output_remap[node.inputsAt(0)] = node
-            elif node.kind().startswith('aten::'):
-                # handle aten::XXX
-                self.global_seq += 1
-                aten_node = ir_graph.add_node(build_full_name(module_name, BasicOpsPT[node.kind()], self.global_seq), node.kind())
-                node_index[node] = aten_node
-                self._add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
             elif node.kind() == 'prim::GetAttr':
-                node_type, attrs = self.handle_prim_attr_node(node)
+                node_type, attrs = self.handle_prim_attr_node(node, module)
                 self.global_seq += 1
                 new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Attr, self.global_seq),
                                              node_type, attrs)
...
@@ -345,6 +444,26 @@ class GraphConverter:
             elif node.kind() == 'prim::Loop':
                 # refer to https://gist.github.com/liuzhe-lz/90c35d9dd6fd7f3f32544940151ab186
                 raise RuntimeError('Loop has not been supported yet!')
+            elif node.kind().startswith('prim::'):
+                self.global_seq += 1
+                prim_op_name = node.kind().replace('::', '__')
+                prim_node = ir_graph.add_node(build_full_name(module_name, prim_op_name, self.global_seq), node.kind())
+                node_index[node] = prim_node
+                self._add_edge(ir_graph, node, graph_inputs, node_index, prim_node, output_remap)
+            elif node.kind() == 'aten::append':
+                self.global_seq += 1
+                aten_op_name = node.kind().replace('::', '__')
+                aten_node = ir_graph.add_node(build_full_name(module_name, aten_op_name, self.global_seq), node.kind())
+                node_index[node] = aten_node
+                self._add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
+                output_remap[node.inputsAt(0)] = node
+            elif node.kind().startswith('aten::'):
+                # handle aten::XXX
+                self.global_seq += 1
+                aten_op_name = node.kind().replace('::', '__')
+                aten_node = ir_graph.add_node(build_full_name(module_name, aten_op_name, self.global_seq), node.kind())
+                node_index[node] = aten_node
+                self._add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
             else:
                 raise RuntimeError('Unsupported kind: {}'.format(node.kind()))
...
@@ -378,6 +497,11 @@ class GraphConverter:
             new_slice_node = ir_graph.add_node(build_full_name(head_node.name, 'merged'), OpTypeName.MergedSlice)
             if len(head_node.incoming_edges) == 4:
                 # when slice is for one dimension list, there are only 4 inputs, thus merge is not needed
+                for edge in head_node.incoming_edges:
+                    edge.tail = new_slice_node
+                for edge in head_node.outgoing_edges:
+                    edge.head = new_slice_node
+                ir_graph.hidden_nodes.remove(head_node)
                 break
             assert len(head_node.incoming_edges) == 5
             for edge in head_node.incoming_edges:
...
@@ -478,7 +602,7 @@ class GraphConverter:
             m_attrs = self._handle_valuechoice(module)
         elif original_type_name == OpTypeName.Placeholder:
             m_attrs = self.modules_arg[id(module)]
-        elif original_type_name in torch.nn.__dict__:
+        elif module.__class__.__module__.startswith('torch.nn') and original_type_name in torch.nn.__dict__:
             # this is a basic module from pytorch, no need to parse its graph
             assert id(module) in self.modules_arg, f'{original_type_name} arguments are not recorded'
             m_attrs = self.modules_arg[id(module)]
...
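Note: the shared_module_index bookkeeping means a module invoked more than once in forward is converted into a single cell plus lightweight 'shared' reference nodes, instead of being duplicated. A minimal, hypothetical module of the kind that would now take this path:

import torch
import torch.nn as nn

class ReusedConv(nn.Module):
    # The same submodule is called twice in forward; after conversion the
    # second call is expected to become a node whose operation is
    # {"type": "shared", "parameters": {"reference": "<full name of conv>"}}.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)

    def forward(self, x):
        y = self.conv(x)      # first invocation: a real cell/node is created
        return self.conv(y)   # second invocation: only a shared reference node

out = ReusedConv()(torch.randn(1, 3, 8, 8))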
nni/retiarii/converter/op_types.py
...
@@ -9,34 +9,8 @@ class OpTypeName(str, Enum):
     """
     Attr = 'Attr'
     Constant = 'Constant'
-    ListConstruct = 'ListConstruct'
-    TupleConstruct = 'TupleConstruct'
     LayerChoice = 'LayerChoice'
     InputChoice = 'InputChoice'
     ValueChoice = 'ValueChoice'
     Placeholder = 'Placeholder'
     MergedSlice = 'MergedSlice'
-
-# deal with aten op
-BasicOpsPT = {
-    'aten::mean': 'Mean',
-    'aten::relu': 'Relu',
-    'aten::add': 'Add',
-    'aten::__getitem__': 'getitem',
-    'aten::append': 'Append',
-    'aten::len': 'Len',
-    'aten::slice': 'Slice',
-    'aten::cat': 'Cat',
-    'aten::size': 'Size',
-    'aten::view': 'View',
-    'aten::reshape': 'Reshape',
-    'aten::eq': 'Eq',
-    'aten::Bool': 'Bool',
-    'aten::empty': 'Empty',
-    'aten::zeros': 'Zeros',
-    'aten::chunk': 'Chunk',
-    'aten::add_': 'Add_'  # %out.3 : Tensor = aten::add_(%out.1, %connection.1, %4)
-}
-
-BasicOpsTF = {}
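Note: with BasicOpsPT removed, graph_gen.py no longer needs a hand-maintained aten-to-name table; the node-name fragment is derived mechanically from the TorchScript kind (replace('::', '__') in the generic branches, split('::')[-1] for the prim construct/unpack branch). A tiny standalone sketch of that string transformation:

def op_name_fragment(kind: str) -> str:
    # 'aten::append' -> 'aten__append', 'prim::GetAttr' -> 'prim__GetAttr'
    return kind.replace('::', '__')

assert op_name_fragment('aten::append') == 'aten__append'
assert op_name_fragment('prim::TupleUnpack') == 'prim__TupleUnpack'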
nni/retiarii/nn/pytorch/mutator.py
...
@@ -45,7 +45,7 @@ class ValueChoiceMutator(Mutator):
         chosen = self.choice(self.candidates)
         for node in self.nodes:
             target = model.get_node_by_name(node.name)
-            target.update_operation('prim::Constant', {'value': chosen})
+            target.update_operation('prim::Constant', {'type': type(chosen).__name__, 'value': chosen})

 def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
...
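Note: the mutator now records the chosen value's Python type next to the value, which is what the new PrimConstant operation keys on during code generation. A small illustrative sketch (variable names are made up) of the parameter dict this produces and the line it could map to:

chosen = 0.5
params = {'type': type(chosen).__name__, 'value': chosen}   # {'type': 'float', 'value': 0.5}

# The codegen side can then branch on params['type'] rather than guessing:
if params['type'] in ('int', 'float', 'bool', 'int[]'):
    line = f"c_1 = {params['value']}"
elif params['type'] == 'str':
    line = f"c_1 = \"{params['value']}\""
print(line)   # c_1 = 0.5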
nni/retiarii/operation.py
...
@@ -83,6 +83,31 @@ class Operation:

 class PyTorchOperation(Operation):
+    @classmethod
+    def _find_subclass(cls, subclass_name):
+        if cls.to_class_name(subclass_name) is not None:
+            subclass_name = 'ModuleOperator'
+        if cls.is_functional(subclass_name):
+            subclass_name = 'FunctionalOperator'
+        for subclass in cls.__subclasses__():
+            if hasattr(subclass, '_ori_type_name') and \
+                subclass_name in subclass._ori_type_name:
+                return subclass
+        return cls
+
+    @classmethod
+    def to_class_name(cls, type_name) -> str:
+        if type_name.startswith('__torch__.'):
+            return type_name[len('__torch__.'):]
+        elif type_name.startswith('__mutated__.'):
+            return type_name[len('__mutated__.'):]
+        else:
+            return None
+
+    @classmethod
+    def is_functional(cls, type_name) -> bool:
+        return type_name.startswith('Function.')
+
     def _to_class_name(self) -> str:
         if self.type.startswith('__torch__.'):
             return self.type[len('__torch__.'):]
...
@@ -106,59 +131,27 @@ class PyTorchOperation(Operation):
             return f'self.{field} = {self._to_class_name()}({kw_params})'
         return None

-    def to_forward_code(self, field: str, output: str, inputs: List[str]) -> str:
-        from .converter.op_types import OpTypeName
-        if self._to_class_name() is not None:
-            return f'{output} = self.{field}({", ".join(inputs)})'
-        elif self.type.startswith('Function.'):
-            func_name = self.type[len('Function.'):]
-            return f'{output} = F.{func_name}({", ".join(inputs)})'
-        elif self.type == 'prim::Constant':
-            if self.parameters:
-                value = self.parameters['value']
-            else:
-                value = None
-            return f'{output} = {value}'
-        elif self.type == 'prim::ListConstruct':
-            return f'{output} = [{", ".join(inputs)}]'
-        elif self.type == 'prim::TupleConstruct':
-            return f'{output} = ({", ".join(inputs)})'
-        elif self.type == 'prim::GetAttr':
-            return f"{output} = {self.parameters['input']}.{self.parameters['name']}"
-        elif self.type == 'aten::mean':
-            return f'{output} = torch.mean({inputs[0]}, {", ".join(inputs[1:-1])}, out={inputs[-1]})'
-        elif self.type == 'aten::__getitem__':
-            assert len(inputs) == 2
-            return f'{output} = {inputs[0]}[{inputs[1]}]'
-        elif self.type == 'aten::append':
-            assert len(inputs) == 2
-            return f'_, {output} = {inputs[0]}.append({inputs[1]}), {inputs[0]}'
-        elif self.type == 'aten::cat':
-            assert len(inputs) == 2
-            return f'{output} = torch.cat({inputs[0]}, dim={inputs[1]})'
-        elif self.type == 'aten::add':
-            return f'{output} = ' + ' + '.join(inputs)
-        elif self.type == OpTypeName.MergedSlice:
-            assert (len(inputs) - 1) % 4 == 0
-            slices = []
-            dim = int((len(inputs) - 1) / 4)
-            for i in range(dim):
-                slices.append(f'{inputs[i*4+2]}:{inputs[i*4+3]}:{inputs[i*4+4]}')
-            slice_str = ','.join(slices)
-            return f'{output} = {inputs[0]}[{slice_str}]'
-        elif self.type == 'aten::size':
-            assert len(inputs) == 2
-            return f'{output} = {inputs[0]}.size({inputs[1]})'
-        elif self.type == 'aten::view':
-            assert len(inputs) == 2
-            return f'{output} = {inputs[0]}.view({inputs[1]})'
-        elif self.type == 'aten::reshape':
-            assert len(inputs) == 2
-            return f'{output} = {inputs[0]}.reshape({inputs[1]})'
-        elif self.type == 'aten::slice':
-            raise RuntimeError('not supposed to have aten::slice operation')
-        elif self.type == 'aten::Bool':
-            return f'{output} = bool({inputs[0]})'
-        else:
-            raise RuntimeError(f'unsupported operation type: {self.type} ? {self._to_class_name()}')
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+        """
+        Parameters
+        ----------
+        field : str
+            the name of member submodule
+        output : str
+            the output name (lvalue) of this line of code
+        inputs : List[str]
+            variables used in this line of code
+        inputs_value : List[Any]
+            some variables are actually constant, their real values are recorded in ```inputs_value```.
+            if not constant, we simply put None at the corresponding index
+
+        Returns
+        -------
+        str
+            generated code line
+        """
+        if self.type == 'aten::slice':
+            raise RuntimeError('not supposed to have aten::slice operation')
+        else:
+            raise RuntimeError(f'unsupported operation type: {self.type} ? {self._to_class_name()}')
...
@@ -212,6 +205,8 @@ class Cell(PyTorchOperation):
         # TODO: ugly, think about how to refactor this part
         return _convert_name(self.cell_name)

+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+        return f'{output} = self.{field}({", ".join(inputs)})'

 class _IOPseudoOperation(Operation):
     """
...
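Note: PyTorchOperation._find_subclass is a registry-by-subclass pattern: each concrete operator class declares the TorchScript type names it handles in _ori_type_name, and the base class walks __subclasses__() to dispatch. A stripped-down sketch of the same pattern with illustrative class names:

class Op:
    _ori_type_name: list = []

    @classmethod
    def find_subclass(cls, type_name: str):
        for subclass in cls.__subclasses__():
            if type_name in subclass._ori_type_name:
                return subclass
        return cls   # fall back to the generic base class

class CatOp(Op):
    _ori_type_name = ['aten::cat']

class GetItemOp(Op):
    _ori_type_name = ['aten::__getitem__']

assert Op.find_subclass('aten::cat') is CatOp
assert Op.find_subclass('aten::unknown') is Op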
nni/retiarii/operation_def/torch_op_def.py

+from typing import (Any, List)
+
+import torch
+
 from ..operation import PyTorchOperation

-class relu(PyTorchOperation):
-    def to_init_code(self, field):
-        return ''
-    def to_forward_code(self, field, output, *inputs) -> str:
-        assert len(inputs) == 1
-        return f'{output} = nn.functional.relu({inputs[0]})'
+mem_format = [
+    'torch.contiguous_format',      # 0
+    'torch.preserve_format',        # 1
+    'torch.channels_last',          # 2
+]
+
+# this snippet is copied from torch/onnx/symbolic_helper.py,
+# the original definition is in c10/core/ScalarType.h
+# This indicates each scalar type's corresponding
+scalar_type_to_pytorch_type = [
+    'torch.uint8',        # 0
+    'torch.int8',         # 1
+    'torch.short',        # 2
+    'torch.int',          # 3
+    'torch.int64',        # 4
+    'torch.half',         # 5
+    'torch.float',        # 6
+    'torch.double',       # 7
+    'torch.complex32',    # 8
+    'torch.complex64',    # 9
+    'torch.complex128',   # 10
+    'torch.bool',         # 11
+]
+
+class NoOpIdentity(PyTorchOperation):
+    """
+    this operator type is added by us
+    """
+    _ori_type_name = ['noop_identity']
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+        return f'{output} = {", ".join(inputs)}'

-class Flatten(PyTorchOperation):
-    def to_init_code(self, field):
-        return ''
-    def to_forward_code(self, field, output, *inputs) -> str:
-        assert len(inputs) == 1
-        return f'{output} = {inputs[0]}.view({inputs[0]}.size(0), -1)'
+class ModuleOperator(PyTorchOperation):
+    _ori_type_name = ['ModuleOperator', 'shared']
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+        return f'{output} = self.{field}({", ".join(inputs)})'
+
+class FunctionalOperator(PyTorchOperation):
+    _ori_type_name = ['FunctionalOperator']
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+        func_name = self.type[len('Function.'):]
+        if not hasattr(torch.nn.functional, func_name):
+            raise RuntimeError('For now, we only support calling independent functions from `torch.nn.functional`, '
+                               f'{func_name} is not in it.')
+        return f'{output} = F.{func_name}({", ".join(inputs)})'
+
+class PrimConstant(PyTorchOperation):
+    _ori_type_name = ['prim::Constant']
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+        # TODO: refactor this part, maybe we can remove the code gen of prim::Constant
+        # TODO: deal with all the types
+        if self.parameters['type'] == 'None':
+            return f'{output} = None'
+        elif self.parameters['type'] in ('int', 'float', 'bool', 'int[]'):
+            return f'{output} = {self.parameters["value"]}'
+        elif self.parameters['type'] == 'str':
+            str_val = self.parameters["value"]
+            return f'{output} = "{str_val}"'
+        elif self.parameters['type'] == 'Device':
+            value = self.parameters['value']
+            return f'{output} = torch.device("{value}")'
+        else:
+            raise RuntimeError(f'unsupported type of prim::Constant: {self.parameters["type"]}')
+
+class PrimListConstruct(PyTorchOperation):
+    _ori_type_name = ['prim::ListConstruct']
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+        return f'{output} = [{", ".join(inputs)}]'
+
+class PrimListUnpack(PyTorchOperation):
+    _ori_type_name = ['prim::ListUnpack']
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+        return f'{output} = {inputs[0]}'

-class ToDevice(PyTorchOperation):
-    def to_init_code(self, field):
-        return ''
-    def to_forward_code(self, field, output, inputs) -> str:
-        assert len(inputs) == 1
-        return f"{output} = {inputs[0]}.to('{self.parameters['device']}')"
+class PrimTupleConstruct(PyTorchOperation):
+    _ori_type_name = ['prim::TupleConstruct']
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+        return f'{output} = ({", ".join(inputs)})'
+
+class PrimTupleUnpack(PyTorchOperation):
+    _ori_type_name = ['prim::TupleUnpack']
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+        # have single output here, because the following code uses index to access the unpacked values
+        assert len(inputs) == 1
+        return f'{output} = {inputs[0]}'
+
+class PrimGetAttr(PyTorchOperation):
+    _ori_type_name = ['prim::GetAttr']
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+        if self.parameters['value'] is not None:
+            return f"{output} = {self.parameters['value']}"
+        else:
+            return f"{output} = {self.parameters['input']}.{self.parameters['name']}"

-class Dense(PyTorchOperation):
-    def to_init_code(self, field):
-        return f"self.{field} = nn.Linear({self.parameters['in_features']}, {self.parameters['out_features']})"
-    def to_forward_code(self, field, output, *inputs) -> str:
-        assert len(inputs) == 1
-        return f'{output} = self.{field}({inputs[0]})'
+class SimpleMember(PyTorchOperation):
+    _ori_type_name = ['prim::is_cuda', 'prim::data']
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+        member_name = self.type.split('::')[-1]
+        return f'{output} = {inputs[0]}.{member_name}'
+
+class AtenContiguous(PyTorchOperation):
+    _ori_type_name = ['aten::contiguous']
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+        # defined in pytorch/c10/core/MemoryFormat.h
+        assert inputs_value[1] in [0, 1, 2]
+        return f'{output} = {inputs[0]}.contiguous(memory_format={mem_format[inputs_value[1]]})'
+
+class AtenGetitem(PyTorchOperation):
+    _ori_type_name = ['aten::__getitem__']
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+        assert len(inputs) == 2
+        return f'{output} = {inputs[0]}[{inputs[1]}]'

-class Softmax(PyTorchOperation):
-    def to_init_code(self, field):
-        return ''
-    def to_forward_code(self, field, output, *inputs) -> str:
-        assert len(inputs) == 1
-        return f'{output} = F.softmax({inputs[0]}, -1)'
+class AtenAppend(PyTorchOperation):
+    _ori_type_name = ['aten::append']
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+        assert len(inputs) == 2
+        return f'_, {output} = {inputs[0]}.append({inputs[1]}), {inputs[0]}'
+
+class MergedSlice(PyTorchOperation):
+    _ori_type_name = ['MergedSlice']
+    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
+        if (len(inputs) - 1) % 4 == 0:
+            slices = []
+            dim = int((len(inputs) - 1) / 4)
+            for i in range(dim):
+                slices.append(f'{inputs[i*4+2]}:{inputs[i*4+3]}:{inputs[i*4+4]}')
+            slice_str = ','.join(slices)
+            return f'{output} = {inputs[0]}[{slice_str}]'
+        elif len(inputs) == 4:
+            # this case is for simple list
+            return f'{output} = {inputs[0]}[{inputs[1]}:{inputs[2]}:{inputs[3]}]'
+        else:
+            raise RuntimeError('Unsupported slice pattern')
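Note: each of these operation classes is a small template that maps input variable names (plus known constant values) to one line of the generated forward method. A standalone sketch of the MergedSlice index arithmetic on concrete, made-up names:

def merged_slice_line(output, inputs):
    # inputs layout: [tensor, dim0, start0, stop0, step0, dim1, start1, stop1, step1, ...]
    assert (len(inputs) - 1) % 4 == 0
    dims = (len(inputs) - 1) // 4
    slices = [f'{inputs[i*4+2]}:{inputs[i*4+3]}:{inputs[i*4+4]}' for i in range(dims)]
    return f'{output} = {inputs[0]}[{",".join(slices)}]'

print(merged_slice_line('out', ['x', '0', '0', '4', '1', '1', '2', '8', '2']))
# out = x[0:4:1,2:8:2]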
# the following Aten classes means these aten ops are not in torch.Tensor
class
AtenBool
(
PyTorchOperation
):
_ori_type_name
=
[
'aten::Bool'
]
def
to_forward_code
(
self
,
field
:
str
,
output
:
str
,
inputs
:
List
[
str
],
inputs_value
:
List
[
Any
]
=
None
)
->
str
:
return
f
'
{
output
}
= bool(
{
inputs
[
0
]
}
)'
class
AtenNot
(
PyTorchOperation
):
_ori_type_name
=
[
'aten::__not__'
]
def
to_forward_code
(
self
,
field
:
str
,
output
:
str
,
inputs
:
List
[
str
],
inputs_value
:
List
[
Any
]
=
None
)
->
str
:
return
f
'
{
output
}
= not
{
inputs
[
0
]
}
'
class
AtenCat
(
PyTorchOperation
):
_ori_type_name
=
[
'aten::cat'
]
def
to_forward_code
(
self
,
field
:
str
,
output
:
str
,
inputs
:
List
[
str
],
inputs_value
:
List
[
Any
]
=
None
)
->
str
:
assert
len
(
inputs
)
==
2
return
f
'
{
output
}
= torch.cat(
{
inputs
[
0
]
}
, dim=
{
inputs
[
1
]
}
)'
#====================================
class
AtenTensors
(
PyTorchOperation
):
_ori_type_name
=
[
'aten::full'
,
'aten::full_like'
,
'aten::empty_like'
,
'aten::ones_like'
,
'aten::zeros_like'
,
'aten::rand'
,
'aten::randn'
,
'aten::scalar_tensor'
,
'aten::new_full'
,
'aten::new_empty'
,
'aten::new_zeros'
,
'aten::arange'
,
'aten::tensor'
,
'aten::ones'
,
'aten::zeros'
]
def
to_forward_code
(
self
,
field
:
str
,
output
:
str
,
inputs
:
List
[
str
],
inputs_value
:
List
[
Any
]
=
None
)
->
str
:
schemas
=
torch
.
_C
.
_jit_get_schemas_for_operator
(
self
.
type
)
# match number of inputs
overloaded_defs
=
[
len
(
s
.
arguments
)
for
s
in
schemas
]
matched
=
overloaded_defs
.
index
(
len
(
inputs
))
args_list
=
[]
for
idx
,
arg
in
enumerate
(
schemas
[
matched
].
arguments
):
if
arg
.
name
==
'dtype'
:
arg_str
=
f
'dtype=
{
scalar_type_to_pytorch_type
[
inputs_value
[
idx
]]
}
'
if
inputs_value
[
idx
]
is
not
None
else
''
elif
arg
.
name
==
'layout'
:
if
inputs_value
[
idx
]
is
not
None
:
arg_str
=
f
'layout=torch.strided'
print
(
'Warning: only support `torch.strided` for now!!!'
)
else
:
arg_str
=
''
elif
arg
.
name
==
'device'
:
arg_str
=
f
'device=torch.device(
{
inputs
[
idx
]
}
)'
if
inputs_value
[
idx
]
is
not
None
else
''
elif
arg
.
name
==
'memory_format'
:
arg_str
=
f
'memory_format=
{
mem_format
[
inputs_value
[
idx
]]
}
'
if
inputs_value
[
idx
]
is
not
None
else
''
elif
arg
.
name
==
'pin_memory'
:
# TODO: deal with this argument
continue
elif
arg
.
name
==
'requires_grad'
:
arg_str
=
f
'requires_grad=
{
inputs
[
idx
]
}
'
if
inputs_value
[
idx
]
else
''
elif
str
(
arg
.
type
).
startswith
(
'Optional['
):
arg_str
=
f
'
{
arg
.
name
}
=
{
inputs
[
idx
]
}
'
else
:
arg_str
=
f
'
{
inputs
[
idx
]
}
'
if
arg_str
!=
''
:
args_list
.
append
(
arg_str
)
op_name
=
self
.
type
.
split
(
'::'
)[
-
1
]
if
hasattr
(
torch
,
op_name
):
return
f
'
{
output
}
= torch.
{
op_name
}
(
{
", "
.
join
(
args_list
)
}
)'
else
:
return
f
'
{
output
}
=
{
inputs
[
0
]
}
.
{
op_name
}
(
{
", "
.
join
(
args_list
[
1
:])
}
)'
#====================================
class
AtenFloordiv
(
PyTorchOperation
):
_ori_type_name
=
[
'aten::floordiv'
]
def
to_forward_code
(
self
,
field
:
str
,
output
:
str
,
inputs
:
List
[
str
],
inputs_value
:
List
[
Any
]
=
None
)
->
str
:
return
f
'
{
output
}
=
{
inputs
[
0
]
}
//
{
inputs
[
1
]
}
'
class
AtenLen
(
PyTorchOperation
):
_ori_type_name
=
[
'aten::len'
]
def
to_forward_code
(
self
,
field
:
str
,
output
:
str
,
inputs
:
List
[
str
],
inputs_value
:
List
[
Any
]
=
None
)
->
str
:
return
f
'
{
output
}
= len(
{
inputs
[
0
]
}
)'
class
AtenIntImplicit
(
PyTorchOperation
):
_ori_type_name
=
[
'aten::IntImplicit'
,
'aten::Float'
,
'aten::Int'
,
'aten::ScalarImplicit'
]
def
to_forward_code
(
self
,
field
:
str
,
output
:
str
,
inputs
:
List
[
str
],
inputs_value
:
List
[
Any
]
=
None
)
->
str
:
if
self
.
type
.
endswith
(
'Implicit'
):
return
f
'
{
output
}
=
{
inputs
[
0
]
}
'
elif
self
.
type
==
'aten::Int'
:
return
f
'
{
output
}
= int(
{
inputs
[
0
]
}
)'
elif
self
.
type
==
'aten::Float'
:
return
f
'
{
output
}
= float(
{
inputs
[
0
]
}
)'
class
AtenIndex
(
PyTorchOperation
):
_ori_type_name
=
[
'aten::index'
]
def
to_forward_code
(
self
,
field
:
str
,
output
:
str
,
inputs
:
List
[
str
],
inputs_value
:
List
[
Any
]
=
None
)
->
str
:
return
f
'
{
output
}
=
{
inputs
[
0
]
}
[
{
inputs
[
1
]
}
]'
ManuallyChooseDef
=
{
'aten::flatten'
:
[(
'start_dim'
,
'int'
,
'0'
),
(
'end_dim'
,
'int'
,
'-1'
)],
'aten::split'
:
[(
'split_size'
,
'int'
,
'None'
),
(
'dim'
,
'int'
,
'0'
)]
}
TensorOpExceptions
=
{
'aten::sub'
:
lambda
output
,
inputs
:
f
'
{
output
}
=
{
inputs
[
0
]
}
-
{
inputs
[
1
]
}
'
,
# example: x.size(1) - 3
'aten::add'
:
lambda
output
,
inputs
:
f
'
{
output
}
=
{
inputs
[
0
]
}
+
{
inputs
[
1
]
}
'
# example: input.shape[0] + 5
}
TorchOpExclude
=
[
'aten::Size'
,
'aten::as_tensor'
,
'aten::device'
,
'aten::manual_seed'
,
'aten::quantized_gru'
,
'aten::quantized_lstm'
,
'aten::save'
,
'aten::tensor'
,
'aten::wait'
]
def
_hidden
(
name
):
return
name
.
startswith
(
'_'
)
and
not
name
.
startswith
(
'__'
)
def
_emit_args
(
args
):
# filter out the `out` argument here
return
[(
arg
.
name
,
str
(
arg
.
type
),
str
(
arg
.
default_value
))
for
arg
in
args
]
# if arg.name != 'out'
def
_get_tensor_ops
():
def
is_tensor_method
(
schema
):
if
len
(
schema
.
arguments
)
==
0
:
return
False
self
=
schema
.
arguments
[
0
]
if
self
.
name
!=
'self'
:
return
False
if
not
self
.
type
.
isSubtypeOf
(
torch
.
_C
.
TensorType
.
get
()):
return
False
return
True
op_args
=
{}
# discover methods
for
elem
in
dir
(
torch
.
Tensor
):
if
not
_hidden
(
elem
):
schemas
=
torch
.
_C
.
_jit_get_schemas_for_operator
(
"aten::"
+
elem
)
for
schema
in
schemas
:
if
is_tensor_method
(
schema
):
op_name
=
'aten::'
+
elem
args
=
_emit_args
(
schema
.
arguments
[
1
:])
if
op_name
in
op_args
:
op_args
[
op_name
].
append
(
args
)
else
:
op_args
[
op_name
]
=
[
args
]
return
op_args
.
keys
(),
op_args
def _get_torch_ops():
    torch_op_args = {}
    for mod in torch.jit._builtins._modules_containing_builtins:
        name = mod.__name__
        if name == 'torch._C._nn':
            continue
        # only process 'torch.XXX'
        for elem in dir(mod):
            builtin = torch.jit._builtins._find_builtin(getattr(mod, elem))
            if builtin is not None:
                schemas = torch._C._jit_get_schemas_for_operator(builtin)
                for schema in schemas:
                    # remove _tan but not __and__
                    if not _hidden(elem):
                        op_name = 'aten::' + elem
                        if len(schema.arguments) > 0 and schema.arguments[0].name == 'self':
                            continue
                        args = _emit_args(schema.arguments)
                        if op_name in torch_op_args:
                            torch_op_args[op_name].append(args)
                        else:
                            torch_op_args[op_name] = [args]
    return torch_op_args.keys(), torch_op_args

def _get_torch_ops_exclude_tensor_ops():
    tensor_op_names, _ = _get_tensor_ops()
    torch_op_names, torch_ops = _get_torch_ops()
    torch_exclude_ops = {}
    for name in torch_op_names:
        if name not in tensor_op_names:
            if name not in TorchOpExclude:
                # exclude the ops that are not in
                # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
                torch_exclude_ops[name] = torch_ops[name]
    return torch_exclude_ops.keys(), torch_exclude_ops
class TensorOps(PyTorchOperation):
    """
    corresponding to _get_tensor_ops in torch.jit.supported_ops
    """
    _ori_type_name, _op_args = _get_tensor_ops()

    comparison_ops = {'aten::eq': '==', 'aten::ne': '!=', 'aten::le': '<=',
                      'aten::ge': '>=', 'aten::lt': '<', 'aten::gt': '>'}

    @staticmethod
    def _get_matched_args(_type, inputs):
        def has_same_arg_name(matched):
            concated_names = []
            for i, each in enumerate(matched):
                name = ','.join([arg[0] for arg in each])
                concated_names.append(name)
            for i in range(len(concated_names) - 1):
                if concated_names[i] != concated_names[i + 1]:
                    return False
            return True

        overloaded_defs = TensorOps._op_args[_type]
        matched = []
        for each in overloaded_defs:
            # plus 1 because we skip the first argument when generating tensor op def
            if len(each) + 1 == len(inputs):
                matched.append(each)
        if len(matched) == 1:
            return matched[0]
        elif len(matched) > 1:
            # TODO: match with arg's type. manually choose for now
            if has_same_arg_name(matched):
                # return any one is okay
                return matched[0]
            elif _type in ManuallyChooseDef:
                return ManuallyChooseDef[_type]
            else:
                raise RuntimeError(f'tensor op type {_type} has more than one matched: {matched}')
        else:
            if _type in TensorOpExceptions:
                return None
            raise RuntimeError(f'tensor op type {_type} has no matched')

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        # TODO: deal with conditional ops
        if self.type in TensorOps.comparison_ops:
            return f'{output} = ({inputs[0]} {TensorOps.comparison_ops[self.type]} {inputs[1]})'
        matched_args = TensorOps._get_matched_args(self.type, inputs)
        if matched_args is None:
            return TensorOpExceptions[self.type](output, inputs)
        op_name = self.type.split('::')[-1]
        args_str = ', '.join([f'{name}={inputs[i+1]}' for i, (name, t, default) in enumerate(matched_args)])
        print(args_str)
        return f'{output} = {inputs[0]}.{op_name}({args_str})'
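As an illustration only (the identifiers below are made up, not taken from a real graph), the string manipulation above turns a tensor-method node plus its matched argument definition into one line of forward code:

# Hypothetical node of type 'aten::argmax' whose inputs are [self tensor, dim, keepdim].
inputs = ['_input_x', '1', 'False']
matched_args = [('dim', 'int?', 'None'), ('keepdim', 'bool', 'False')]  # what _get_matched_args would return
op_name = 'aten::argmax'.split('::')[-1]
args_str = ', '.join(f'{name}={inputs[i + 1]}' for i, (name, _t, _default) in enumerate(matched_args))
print(f'__output = {inputs[0]}.{op_name}({args_str})')   # -> __output = _input_x.argmax(dim=1, keepdim=False)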
class TorchOps(PyTorchOperation):
    """
    corresponding to _get_nn_functional_ops in torch.jit.supported_ops
    """
    _ori_type_name, _op_args = _get_torch_ops_exclude_tensor_ops()
    # add 'aten::pixel_shuffle'
    _op_args['aten::pixel_shuffle'] = [[('input', 'Tensor', 'None'), ('upscale_factor', 'Optional[int]', 'None')]]
    _ori_type_name = _op_args.keys()

    @staticmethod
    def _get_matched_args(_type, inputs):
        def has_same_arg_name(matched):
            concated_names = []
            for i, each in enumerate(matched):
                name = ','.join([arg[0] for arg in each])
                concated_names.append(name)
            for i in range(len(concated_names) - 1):
                if concated_names[i] != concated_names[i + 1]:
                    return False
            return True

        overloaded_defs = TorchOps._op_args[_type]
        matched = []
        for each in overloaded_defs:
            if len(each) == len(inputs):
                matched.append(each)
        if len(matched) == 1:
            return matched[0]
        elif len(matched) > 1:
            # TODO: match with arg's type. manually choose for now
            if has_same_arg_name(matched):
                # return any one is okay
                return matched[0]
            else:
                raise RuntimeError(f'torch op type {_type} has more than one matched: {matched}')
        else:
            raise RuntimeError(f'torch op type {_type} has no matched')

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        matched_args = TorchOps._get_matched_args(self.type, inputs)
        op_name = self.type.split('::')[-1]
        args_str = ', '.join([f'{name}={inputs[i]}' if t.startswith('Optional[') else f'{inputs[i]}'
                              for i, (name, t, default) in enumerate(matched_args)])
        return f'{output} = torch.{op_name}({args_str})'
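A small sketch of the keyword-versus-positional rule used above: arguments whose schema type starts with 'Optional[' are emitted as keywords, everything else positionally. The values below are placeholders chosen to mirror the manually added pixel_shuffle definition:

inputs = ['_x', '2']
matched_args = [('input', 'Tensor', 'None'), ('upscale_factor', 'Optional[int]', 'None')]
args_str = ', '.join(f'{name}={inputs[i]}' if t.startswith('Optional[') else f'{inputs[i]}'
                     for i, (name, t, _default) in enumerate(matched_args))
print(f'__output = torch.pixel_shuffle({args_str})')   # -> __output = torch.pixel_shuffle(_x, upscale_factor=2)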
class AtenAvgpool2d(PyTorchOperation):
    # NOTE: it is not included in the above aten ops for unknown reason
    _ori_type_name = ['aten::avg_pool2d']
    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
        return f'{output} = F.avg_pool2d({", ".join(inputs)})'
nni/retiarii/utils.py
@@ -162,6 +162,14 @@ def _get_module_name(cls):
                 f'please launch the experiment under the directory where "{main_file_path.name}" is located.')
             module_name = main_file_path.stem
             break
+    # NOTE: this is hacky. As torchscript retrieves LSTM's source code to do something.
+    # to make LSTM's source code can be found, we should assign original LSTM's __module__ to
+    # the wrapped LSTM's __module__
+    # TODO: find out all the modules that have the same requirement as LSTM
+    if f'{cls.__module__}.{cls.__name__}' == 'torch.nn.modules.rnn.LSTM':
+        module_name = cls.__module__
     return module_name
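The condition added above compares the fully qualified class name; as a quick sanity check on plain torch.nn (outside the wrapping logic):

import torch.nn

cls = torch.nn.LSTM
print(f'{cls.__module__}.{cls.__name__}')   # torch.nn.modules.rnn.LSTM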
pipelines/fast-test.yml
@@ -250,7 +250,9 @@ stages:
   - script: |
       cd test
-      python -m pytest ut
+      python -m pytest ut --ignore=ut/retiarii/test_convert_basic.py \
+        --ignore=ut/retiarii/test_convert_operators.py \
+        --ignore=ut/retiarii/test_convert_pytorch.py
     displayName: Python unit test
   - script: |
test/ut/retiarii/inject_nn.py
(new file)
import inspect
import logging

import torch
import torch.nn as nn

from nni.retiarii.utils import add_record, del_record, version_larger_equal

_logger = logging.getLogger(__name__)


def wrap_module(original_class):
    orig_init = original_class.__init__
    argname_list = list(inspect.signature(original_class).parameters.keys())
    # Make copy of original __init__, so we can call it without recursion
    original_class.bak_init_for_inject = orig_init
    if hasattr(original_class, '__del__'):
        orig_del = original_class.__del__
        original_class.bak_del_for_inject = orig_del
    else:
        orig_del = None
        original_class.bak_del_for_inject = None

    def __init__(self, *args, **kws):
        full_args = {}
        full_args.update(kws)
        for i, arg in enumerate(args):
            full_args[argname_list[i]] = arg
        add_record(id(self), full_args)

        orig_init(self, *args, **kws)  # Call the original __init__

    def __del__(self):
        del_record(id(self))
        if orig_del is not None:
            orig_del(self)

    original_class.__init__ = __init__  # Set the class' __init__ to the new one
    original_class.__del__ = __del__
    return original_class
def unwrap_module(wrapped_class):
    if hasattr(wrapped_class, 'bak_init_for_inject'):
        wrapped_class.__init__ = wrapped_class.bak_init_for_inject
        delattr(wrapped_class, 'bak_init_for_inject')
    if hasattr(wrapped_class, 'bak_del_for_inject'):
        if wrapped_class.bak_del_for_inject is not None:
            wrapped_class.__del__ = wrapped_class.bak_del_for_inject
        delattr(wrapped_class, 'bak_del_for_inject')
    return None
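A hedged usage sketch of the two helpers above (the arguments are chosen arbitrarily): wrapping a class records each instance's constructor arguments via add_record, and unwrapping restores the original __init__/__del__:

import torch.nn as nn

wrap_module(nn.Linear)
layer = nn.Linear(4, 8)    # add_record(id(layer), {'in_features': 4, 'out_features': 8}) runs here
unwrap_module(nn.Linear)   # nn.Linear behaves as before afterwards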
def remove_inject_pytorch_nn():
    Identity = unwrap_module(nn.Identity)
    Linear = unwrap_module(nn.Linear)
    Conv1d = unwrap_module(nn.Conv1d)
    Conv2d = unwrap_module(nn.Conv2d)
    Conv3d = unwrap_module(nn.Conv3d)
    ConvTranspose1d = unwrap_module(nn.ConvTranspose1d)
    ConvTranspose2d = unwrap_module(nn.ConvTranspose2d)
    ConvTranspose3d = unwrap_module(nn.ConvTranspose3d)
    Threshold = unwrap_module(nn.Threshold)
    ReLU = unwrap_module(nn.ReLU)
    Hardtanh = unwrap_module(nn.Hardtanh)
    ReLU6 = unwrap_module(nn.ReLU6)
    Sigmoid = unwrap_module(nn.Sigmoid)
    Tanh = unwrap_module(nn.Tanh)
    Softmax = unwrap_module(nn.Softmax)
    Softmax2d = unwrap_module(nn.Softmax2d)
    LogSoftmax = unwrap_module(nn.LogSoftmax)
    ELU = unwrap_module(nn.ELU)
    SELU = unwrap_module(nn.SELU)
    CELU = unwrap_module(nn.CELU)
    GLU = unwrap_module(nn.GLU)
    GELU = unwrap_module(nn.GELU)
    Hardshrink = unwrap_module(nn.Hardshrink)
    LeakyReLU = unwrap_module(nn.LeakyReLU)
    LogSigmoid = unwrap_module(nn.LogSigmoid)
    Softplus = unwrap_module(nn.Softplus)
    Softshrink = unwrap_module(nn.Softshrink)
    MultiheadAttention = unwrap_module(nn.MultiheadAttention)
    PReLU = unwrap_module(nn.PReLU)
    Softsign = unwrap_module(nn.Softsign)
    Softmin = unwrap_module(nn.Softmin)
    Tanhshrink = unwrap_module(nn.Tanhshrink)
    RReLU = unwrap_module(nn.RReLU)
    AvgPool1d = unwrap_module(nn.AvgPool1d)
    AvgPool2d = unwrap_module(nn.AvgPool2d)
    AvgPool3d = unwrap_module(nn.AvgPool3d)
    MaxPool1d = unwrap_module(nn.MaxPool1d)
    MaxPool2d = unwrap_module(nn.MaxPool2d)
    MaxPool3d = unwrap_module(nn.MaxPool3d)
    MaxUnpool1d = unwrap_module(nn.MaxUnpool1d)
    MaxUnpool2d = unwrap_module(nn.MaxUnpool2d)
    MaxUnpool3d = unwrap_module(nn.MaxUnpool3d)
    FractionalMaxPool2d = unwrap_module(nn.FractionalMaxPool2d)
    FractionalMaxPool3d = unwrap_module(nn.FractionalMaxPool3d)
    LPPool1d = unwrap_module(nn.LPPool1d)
    LPPool2d = unwrap_module(nn.LPPool2d)
    LocalResponseNorm = unwrap_module(nn.LocalResponseNorm)
    BatchNorm1d = unwrap_module(nn.BatchNorm1d)
    BatchNorm2d = unwrap_module(nn.BatchNorm2d)
    BatchNorm3d = unwrap_module(nn.BatchNorm3d)
    InstanceNorm1d = unwrap_module(nn.InstanceNorm1d)
    InstanceNorm2d = unwrap_module(nn.InstanceNorm2d)
    InstanceNorm3d = unwrap_module(nn.InstanceNorm3d)
    LayerNorm = unwrap_module(nn.LayerNorm)
    GroupNorm = unwrap_module(nn.GroupNorm)
    SyncBatchNorm = unwrap_module(nn.SyncBatchNorm)
    Dropout = unwrap_module(nn.Dropout)
    Dropout2d = unwrap_module(nn.Dropout2d)
    Dropout3d = unwrap_module(nn.Dropout3d)
    AlphaDropout = unwrap_module(nn.AlphaDropout)
    FeatureAlphaDropout = unwrap_module(nn.FeatureAlphaDropout)
    ReflectionPad1d = unwrap_module(nn.ReflectionPad1d)
    ReflectionPad2d = unwrap_module(nn.ReflectionPad2d)
    ReplicationPad2d = unwrap_module(nn.ReplicationPad2d)
    ReplicationPad1d = unwrap_module(nn.ReplicationPad1d)
    ReplicationPad3d = unwrap_module(nn.ReplicationPad3d)
    CrossMapLRN2d = unwrap_module(nn.CrossMapLRN2d)
    Embedding = unwrap_module(nn.Embedding)
    EmbeddingBag = unwrap_module(nn.EmbeddingBag)
    RNNBase = unwrap_module(nn.RNNBase)
    RNN = unwrap_module(nn.RNN)
    LSTM = unwrap_module(nn.LSTM)
    GRU = unwrap_module(nn.GRU)
    RNNCellBase = unwrap_module(nn.RNNCellBase)
    RNNCell = unwrap_module(nn.RNNCell)
    LSTMCell = unwrap_module(nn.LSTMCell)
    GRUCell = unwrap_module(nn.GRUCell)
    PixelShuffle = unwrap_module(nn.PixelShuffle)
    Upsample = unwrap_module(nn.Upsample)
    UpsamplingNearest2d = unwrap_module(nn.UpsamplingNearest2d)
    UpsamplingBilinear2d = unwrap_module(nn.UpsamplingBilinear2d)
    PairwiseDistance = unwrap_module(nn.PairwiseDistance)
    AdaptiveMaxPool1d = unwrap_module(nn.AdaptiveMaxPool1d)
    AdaptiveMaxPool2d = unwrap_module(nn.AdaptiveMaxPool2d)
    AdaptiveMaxPool3d = unwrap_module(nn.AdaptiveMaxPool3d)
    AdaptiveAvgPool1d = unwrap_module(nn.AdaptiveAvgPool1d)
    AdaptiveAvgPool2d = unwrap_module(nn.AdaptiveAvgPool2d)
    AdaptiveAvgPool3d = unwrap_module(nn.AdaptiveAvgPool3d)
    TripletMarginLoss = unwrap_module(nn.TripletMarginLoss)
    ZeroPad2d = unwrap_module(nn.ZeroPad2d)
    ConstantPad1d = unwrap_module(nn.ConstantPad1d)
    ConstantPad2d = unwrap_module(nn.ConstantPad2d)
    ConstantPad3d = unwrap_module(nn.ConstantPad3d)
    Bilinear = unwrap_module(nn.Bilinear)
    CosineSimilarity = unwrap_module(nn.CosineSimilarity)
    Unfold = unwrap_module(nn.Unfold)
    Fold = unwrap_module(nn.Fold)
    AdaptiveLogSoftmaxWithLoss = unwrap_module(nn.AdaptiveLogSoftmaxWithLoss)
    TransformerEncoder = unwrap_module(nn.TransformerEncoder)
    TransformerDecoder = unwrap_module(nn.TransformerDecoder)
    TransformerEncoderLayer = unwrap_module(nn.TransformerEncoderLayer)
    TransformerDecoderLayer = unwrap_module(nn.TransformerDecoderLayer)
    Transformer = unwrap_module(nn.Transformer)
    Flatten = unwrap_module(nn.Flatten)
    Hardsigmoid = unwrap_module(nn.Hardsigmoid)
    if version_larger_equal(torch.__version__, '1.6.0'):
        Hardswish = unwrap_module(nn.Hardswish)
    if version_larger_equal(torch.__version__, '1.7.0'):
        SiLU = unwrap_module(nn.SiLU)
        Unflatten = unwrap_module(nn.Unflatten)
        TripletMarginWithDistanceLoss = unwrap_module(nn.TripletMarginWithDistanceLoss)
def inject_pytorch_nn():
    Identity = wrap_module(nn.Identity)
    Linear = wrap_module(nn.Linear)
    Conv1d = wrap_module(nn.Conv1d)
    Conv2d = wrap_module(nn.Conv2d)
    Conv3d = wrap_module(nn.Conv3d)
    ConvTranspose1d = wrap_module(nn.ConvTranspose1d)
    ConvTranspose2d = wrap_module(nn.ConvTranspose2d)
    ConvTranspose3d = wrap_module(nn.ConvTranspose3d)
    Threshold = wrap_module(nn.Threshold)
    ReLU = wrap_module(nn.ReLU)
    Hardtanh = wrap_module(nn.Hardtanh)
    ReLU6 = wrap_module(nn.ReLU6)
    Sigmoid = wrap_module(nn.Sigmoid)
    Tanh = wrap_module(nn.Tanh)
    Softmax = wrap_module(nn.Softmax)
    Softmax2d = wrap_module(nn.Softmax2d)
    LogSoftmax = wrap_module(nn.LogSoftmax)
    ELU = wrap_module(nn.ELU)
    SELU = wrap_module(nn.SELU)
    CELU = wrap_module(nn.CELU)
    GLU = wrap_module(nn.GLU)
    GELU = wrap_module(nn.GELU)
    Hardshrink = wrap_module(nn.Hardshrink)
    LeakyReLU = wrap_module(nn.LeakyReLU)
    LogSigmoid = wrap_module(nn.LogSigmoid)
    Softplus = wrap_module(nn.Softplus)
    Softshrink = wrap_module(nn.Softshrink)
    MultiheadAttention = wrap_module(nn.MultiheadAttention)
    PReLU = wrap_module(nn.PReLU)
    Softsign = wrap_module(nn.Softsign)
    Softmin = wrap_module(nn.Softmin)
    Tanhshrink = wrap_module(nn.Tanhshrink)
    RReLU = wrap_module(nn.RReLU)
    AvgPool1d = wrap_module(nn.AvgPool1d)
    AvgPool2d = wrap_module(nn.AvgPool2d)
    AvgPool3d = wrap_module(nn.AvgPool3d)
    MaxPool1d = wrap_module(nn.MaxPool1d)
    MaxPool2d = wrap_module(nn.MaxPool2d)
    MaxPool3d = wrap_module(nn.MaxPool3d)
    MaxUnpool1d = wrap_module(nn.MaxUnpool1d)
    MaxUnpool2d = wrap_module(nn.MaxUnpool2d)
    MaxUnpool3d = wrap_module(nn.MaxUnpool3d)
    FractionalMaxPool2d = wrap_module(nn.FractionalMaxPool2d)
    FractionalMaxPool3d = wrap_module(nn.FractionalMaxPool3d)
    LPPool1d = wrap_module(nn.LPPool1d)
    LPPool2d = wrap_module(nn.LPPool2d)
    LocalResponseNorm = wrap_module(nn.LocalResponseNorm)
    BatchNorm1d = wrap_module(nn.BatchNorm1d)
    BatchNorm2d = wrap_module(nn.BatchNorm2d)
    BatchNorm3d = wrap_module(nn.BatchNorm3d)
    InstanceNorm1d = wrap_module(nn.InstanceNorm1d)
    InstanceNorm2d = wrap_module(nn.InstanceNorm2d)
    InstanceNorm3d = wrap_module(nn.InstanceNorm3d)
    LayerNorm = wrap_module(nn.LayerNorm)
    GroupNorm = wrap_module(nn.GroupNorm)
    SyncBatchNorm = wrap_module(nn.SyncBatchNorm)
    Dropout = wrap_module(nn.Dropout)
    Dropout2d = wrap_module(nn.Dropout2d)
    Dropout3d = wrap_module(nn.Dropout3d)
    AlphaDropout = wrap_module(nn.AlphaDropout)
    FeatureAlphaDropout = wrap_module(nn.FeatureAlphaDropout)
    ReflectionPad1d = wrap_module(nn.ReflectionPad1d)
    ReflectionPad2d = wrap_module(nn.ReflectionPad2d)
    ReplicationPad2d = wrap_module(nn.ReplicationPad2d)
    ReplicationPad1d = wrap_module(nn.ReplicationPad1d)
    ReplicationPad3d = wrap_module(nn.ReplicationPad3d)
    CrossMapLRN2d = wrap_module(nn.CrossMapLRN2d)
    Embedding = wrap_module(nn.Embedding)
    EmbeddingBag = wrap_module(nn.EmbeddingBag)
    RNNBase = wrap_module(nn.RNNBase)
    RNN = wrap_module(nn.RNN)
    LSTM = wrap_module(nn.LSTM)
    GRU = wrap_module(nn.GRU)
    RNNCellBase = wrap_module(nn.RNNCellBase)
    RNNCell = wrap_module(nn.RNNCell)
    LSTMCell = wrap_module(nn.LSTMCell)
    GRUCell = wrap_module(nn.GRUCell)
    PixelShuffle = wrap_module(nn.PixelShuffle)
    Upsample = wrap_module(nn.Upsample)
    UpsamplingNearest2d = wrap_module(nn.UpsamplingNearest2d)
    UpsamplingBilinear2d = wrap_module(nn.UpsamplingBilinear2d)
    PairwiseDistance = wrap_module(nn.PairwiseDistance)
    AdaptiveMaxPool1d = wrap_module(nn.AdaptiveMaxPool1d)
    AdaptiveMaxPool2d = wrap_module(nn.AdaptiveMaxPool2d)
    AdaptiveMaxPool3d = wrap_module(nn.AdaptiveMaxPool3d)
    AdaptiveAvgPool1d = wrap_module(nn.AdaptiveAvgPool1d)
    AdaptiveAvgPool2d = wrap_module(nn.AdaptiveAvgPool2d)
    AdaptiveAvgPool3d = wrap_module(nn.AdaptiveAvgPool3d)
    TripletMarginLoss = wrap_module(nn.TripletMarginLoss)
    ZeroPad2d = wrap_module(nn.ZeroPad2d)
    ConstantPad1d = wrap_module(nn.ConstantPad1d)
    ConstantPad2d = wrap_module(nn.ConstantPad2d)
    ConstantPad3d = wrap_module(nn.ConstantPad3d)
    Bilinear = wrap_module(nn.Bilinear)
    CosineSimilarity = wrap_module(nn.CosineSimilarity)
    Unfold = wrap_module(nn.Unfold)
    Fold = wrap_module(nn.Fold)
    AdaptiveLogSoftmaxWithLoss = wrap_module(nn.AdaptiveLogSoftmaxWithLoss)
    TransformerEncoder = wrap_module(nn.TransformerEncoder)
    TransformerDecoder = wrap_module(nn.TransformerDecoder)
    TransformerEncoderLayer = wrap_module(nn.TransformerEncoderLayer)
    TransformerDecoderLayer = wrap_module(nn.TransformerDecoderLayer)
    Transformer = wrap_module(nn.Transformer)
    Flatten = wrap_module(nn.Flatten)
    Hardsigmoid = wrap_module(nn.Hardsigmoid)
    if version_larger_equal(torch.__version__, '1.6.0'):
        Hardswish = wrap_module(nn.Hardswish)
    if version_larger_equal(torch.__version__, '1.7.0'):
        SiLU = wrap_module(nn.SiLU)
        Unflatten = wrap_module(nn.Unflatten)
        TripletMarginWithDistanceLoss = wrap_module(nn.TripletMarginWithDistanceLoss)
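In the tests below these two entry points bracket model construction; a hedged sketch of the pattern (assuming the module is importable on the test path):

from inject_nn import inject_pytorch_nn, remove_inject_pytorch_nn
import torchvision

inject_pytorch_nn()                    # patch torch.nn so __init__ arguments are recorded
model = torchvision.models.resnet18()
remove_inject_pytorch_nn()             # restore the original torch.nn classes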
test/ut/retiarii/test_convert.py
@@ -35,16 +35,29 @@ class MnistNet(nn.Module):
         x = self.fc2(x)
         return F.log_softmax(x, dim=1)

+# NOTE: blackbox module cannot be placed within class or function
+@blackbox_module
+class Linear(nn.Module):
+    def __init__(self, d_embed, d_proj):
+        super().__init__()
+        self.linear = nn.Linear(d_embed, d_proj)
+
+    def forward(self, input):
+        if len(input.size()) <= 2:
+            return self.linear(input)
+        size = input.size()[:2]
+        out = self.linear(input.view(size[0] * size[1], -1))
+        return out.view(size[0], size[1], -1)
+
 class TestConvert(unittest.TestCase):
     @staticmethod
     def _match_state_dict(current_values, expected_format):
         result = {}
         for k, v in expected_format.items():
-            for cv in current_values:
+            for idx, cv in enumerate(current_values):
                 if cv.shape == v.shape:
                     result[k] = cv
-                    current_values.remove(cv)
+                    current_values.pop(idx)
                     break
         return result
@@ -53,6 +66,9 @@
         model_ir = convert_to_graph(script_module, model)
         model_code = model_to_pytorch_script(model_ir)
+        from .inject_nn import remove_inject_pytorch_nn
+        remove_inject_pytorch_nn()
+
         exec_vars = {}
         exec(model_code + '\n\nconverted_model = _model()', exec_vars)
         converted_model = exec_vars['converted_model']
@@ -134,18 +150,17 @@
         model = DCGANGenerator(nz, ngf, nc)
         self.checkExportImport(model, input)

-    @unittest.skip('this test has a if condition that needs to be handle')  # FIXME
     def test_neural_style(self):
-        class TransformerNet(torch.nn.Module):
+        class TransformerNet(nn.Module):
             def __init__(self):
                 super(TransformerNet, self).__init__()
                 # Initial convolution layers
                 self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
-                self.in1 = torch.nn.InstanceNorm2d(32, affine=True)
+                self.in1 = nn.InstanceNorm2d(32, affine=True)
                 self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
-                self.in2 = torch.nn.InstanceNorm2d(64, affine=True)
+                self.in2 = nn.InstanceNorm2d(64, affine=True)
                 self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
-                self.in3 = torch.nn.InstanceNorm2d(128, affine=True)
+                self.in3 = nn.InstanceNorm2d(128, affine=True)
                 # Residual layers
                 self.res1 = ResidualBlock(128)
                 self.res2 = ResidualBlock(128)
@@ -154,12 +169,12 @@
                 self.res5 = ResidualBlock(128)
                 # Upsampling Layers
                 self.deconv1 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2)
-                self.in4 = torch.nn.InstanceNorm2d(64, affine=True)
+                self.in4 = nn.InstanceNorm2d(64, affine=True)
                 self.deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2)
-                self.in5 = torch.nn.InstanceNorm2d(32, affine=True)
+                self.in5 = nn.InstanceNorm2d(32, affine=True)
                 self.deconv3 = ConvLayer(32, 3, kernel_size=9, stride=1)
                 # Non-linearities
-                self.relu = torch.nn.ReLU()
+                self.relu = nn.ReLU()

             def forward(self, X):
                 y = self.relu(self.in1(self.conv1(X)))
@@ -175,19 +190,19 @@
                 y = self.deconv3(y)
                 return y

-        class ConvLayer(torch.nn.Module):
+        class ConvLayer(nn.Module):
             def __init__(self, in_channels, out_channels, kernel_size, stride):
                 super(ConvLayer, self).__init__()
                 reflection_padding = kernel_size // 2
-                self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
-                self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)
+                self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
+                self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)

             def forward(self, x):
                 out = self.reflection_pad(x)
                 out = self.conv2d(out)
                 return out

-        class ResidualBlock(torch.nn.Module):
+        class ResidualBlock(nn.Module):
             """ResidualBlock
             introduced in: https://arxiv.org/abs/1512.03385
             recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html
@@ -196,10 +211,10 @@
             def __init__(self, channels):
                 super(ResidualBlock, self).__init__()
                 self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
-                self.in1 = torch.nn.InstanceNorm2d(channels, affine=True)
+                self.in1 = nn.InstanceNorm2d(channels, affine=True)
                 self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
-                self.in2 = torch.nn.InstanceNorm2d(channels, affine=True)
-                self.relu = torch.nn.ReLU()
+                self.in2 = nn.InstanceNorm2d(channels, affine=True)
+                self.relu = nn.ReLU()

             def forward(self, x):
                 residual = x
@@ -208,7 +223,7 @@
                 out = out + residual
                 return out

-        class UpsampleConvLayer(torch.nn.Module):
+        class UpsampleConvLayer(nn.Module):
             """UpsampleConvLayer
             Upsamples the input and then does a convolution. This method gives better results
             compared to ConvTranspose2d.
@@ -219,10 +234,10 @@
                 super(UpsampleConvLayer, self).__init__()
                 self.upsample = upsample
                 if upsample:
-                    self.upsample_layer = torch.nn.Upsample(mode='nearest', scale_factor=upsample)
+                    self.upsample_layer = nn.Upsample(mode='nearest', scale_factor=upsample)
                 reflection_padding = kernel_size // 2
-                self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
-                self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)
+                self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
+                self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)

             def forward(self, x):
                 x_in = x
@@ -254,50 +269,40 @@
         self.checkExportImport(Policy(), (torch.rand(1, 4),))

-    @unittest.skip('Replaced init error.')  # FIXME
     def test_snli(self):
-        class Bottle(nn.Module):
-            def forward(self, input):
-                if len(input.size()) <= 2:
-                    return super(Bottle, self).forward(input)
-                size = input.size()[:2]
-                out = super(Bottle, self).forward(input.view(size[0] * size[1], -1))
-                return out.view(size[0], size[1], -1)
-
-        class Linear(Bottle, nn.Linear):
-            pass
-
         class Encoder(nn.Module):
             def __init__(self, config):
                 super(Encoder, self).__init__()
-                self.config = config
-                input_size = config.d_proj if config.projection else config.d_embed
-                dropout = 0 if config.n_layers == 1 else config.dp_ratio
-                self.rnn = nn.LSTM(input_size=input_size, hidden_size=config.d_hidden,
-                                   num_layers=config.n_layers, dropout=dropout,
-                                   bidirectional=config.birnn)
+                #self.config = config
+                input_size = config["d_proj"] if config["projection"] else config["d_embed"]
+                dropout = 0 if config["n_layers"] == 1 else config["dp_ratio"]
+                self.rnn = nn.LSTM(input_size=input_size, hidden_size=config["d_hidden"],
+                                   num_layers=config["n_layers"], dropout=dropout,
+                                   bidirectional=config["birnn"])
+                self.n_cells = config["n_cells"]
+                self.d_hidden = config["d_hidden"]
+                self.birnn = config["birnn"]

             def forward(self, inputs):
                 batch_size = inputs.size()[1]
-                state_shape = self.config.n_cells, batch_size, self.config.d_hidden
+                state_shape = self.n_cells, batch_size, self.d_hidden
                 h0 = c0 = inputs.new_zeros(state_shape)
                 outputs, (ht, ct) = self.rnn(inputs, (h0, c0))
-                return ht[-1] if not self.config.birnn else ht[-2:].transpose(0, 1).contiguous().view(batch_size, -1)
+                return ht[-1] if not self.birnn else ht[-2:].transpose(0, 1).contiguous().view(batch_size, -1)

         class SNLIClassifier(nn.Module):
             def __init__(self, config):
                 super(SNLIClassifier, self).__init__()
-                self.config = config
-                self.embed = nn.Embedding(config.n_embed, config.d_embed)
-                self.projection = Linear(config.d_embed, config.d_proj)
+                self.embed = nn.Embedding(config["n_embed"], config["d_embed"])
+                self.projection = Linear(config["d_embed"], config["d_proj"])
                 self.encoder = Encoder(config)
-                self.dropout = nn.Dropout(p=config.dp_ratio)
+                self.dropout = nn.Dropout(p=config["dp_ratio"])
                 self.relu = nn.ReLU()
-                seq_in_size = 2 * config.d_hidden
-                if self.config.birnn:
+                seq_in_size = 2 * config["d_hidden"]
+                if config["birnn"]:
                     seq_in_size *= 2
                 lin_config = [seq_in_size] * 2
                 self.out = nn.Sequential(
@@ -310,15 +315,17 @@
                     Linear(*lin_config),
                     self.relu,
                     self.dropout,
-                    Linear(seq_in_size, config.d_out))
+                    Linear(seq_in_size, config["d_out"]))
+                self.fix_emb = config["fix_emb"]
+                self.project = config["projection"]

             def forward(self, premise, hypothesis):
                 prem_embed = self.embed(premise)
                 hypo_embed = self.embed(hypothesis)
-                if self.config.fix_emb:
+                if self.fix_emb:
                     prem_embed = prem_embed.detach()
                     hypo_embed = hypo_embed.detach()
-                if self.config.projection:
+                if self.project:
                     prem_embed = self.relu(self.projection(prem_embed))
                     hypo_embed = self.relu(self.projection(hypo_embed))
                 premise = self.encoder(prem_embed)
@@ -326,23 +333,24 @@
                 scores = self.out(torch.cat([premise, hypothesis], 1))
                 return scores

-        class Config:
-            n_embed = 100
-            d_embed = 100
-            d_proj = 300
-            dp_ratio = 0.0  # For deterministic testing TODO: change by fixing seed in checkTrace?
-            d_hidden = 30
-            birnn = True
-            d_out = 300
-            fix_emb = True
-            projection = True
-            n_layers = 2
-            n_cells = 4  # 2 * n_layers because birnn = True
+        Config = {
+            "n_embed": 100,
+            "d_embed": 100,
+            "d_proj": 300,
+            "dp_ratio": 0.0,  # For deterministic testing TODO: change by fixing seed in checkTrace?
+            "d_hidden": 30,
+            "birnn": True,
+            "d_out": 300,
+            "fix_emb": True,
+            "projection": True,
+            "n_layers": 2,
+            "n_cells": 4  # 2 * n_layers because birnn = True
+        }

         premise = torch.LongTensor(48, 64).random_(0, 100)
         hypothesis = torch.LongTensor(24, 64).random_(0, 100)

-        self.checkExportImport(SNLIClassifier(Config()), (premise, hypothesis))
+        self.checkExportImport(SNLIClassifier(Config), (premise, hypothesis))

     def test_super_resolution(self):
         class Net(nn.Module):
@@ -367,16 +375,16 @@
         net = Net(upscale_factor=4)
         self.checkExportImport(net, (torch.rand(5, 1, 32, 32),))

-    @unittest.skip('Need to support operator prim::ListUnpack')  # FIXME
+    @unittest.skip('Need to support Loop')  # FIXME
     def test_time_sequence_prediction(self):
-        class Sequence(torch.jit.ScriptModule):
+        class Sequence(nn.Module):  # torch.jit.ScriptModule
             def __init__(self):
                 super(Sequence, self).__init__()
                 self.lstm1 = nn.LSTMCell(1, 51)
                 self.lstm2 = nn.LSTMCell(51, 51)
                 self.linear = nn.Linear(51, 1)

-            @torch.jit.script_method
+            # @torch.jit.script_method
             def forward(self, input):
                 # TODO: add future as input with default val
                 # see https://github.com/pytorch/pytorch/issues/8724
@@ -414,7 +422,7 @@
         self.checkExportImport(Traced(), (torch.rand(3, 4),))

-    @unittest.skip('Unsupported callmethod encode')  # FIXME
+    @unittest.skip('incorrectly assigned weights')  # FIXME
     def test_vae(self):
         class VAE(nn.Module):
             def __init__(self):
@@ -449,11 +457,11 @@
         self.checkExportImport(VAE().eval(), (torch.rand(128, 1, 28, 28),))

-    @unittest.skip('torchvision models are not supported yet')  # FIXME
     def test_torchvision_resnet18(self):
+        from .inject_nn import inject_pytorch_nn
+        inject_pytorch_nn()
         self.checkExportImport(torchvision.models.resnet18().eval(), (torch.ones(1, 3, 224, 224),))

-    @unittest.skip('Unsupported CallMethod _forward_impl')  # FIXME
     def test_resnet(self):
         def conv1x1(in_planes, out_planes, stride=1):
             """1x1 convolution"""
@@ -464,7 +472,7 @@
             return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                              padding=1, bias=False)

-        class BasicBlock(torch.jit.ScriptModule):
+        class BasicBlock(nn.Module):  # torch.jit.ScriptModule
             expansion = 1
             __constants__ = ['downsample']
@@ -478,7 +486,8 @@
                 self.downsample = downsample
                 self.stride = stride

-            @torch.jit.script_method
+            # NOTE: jit cannot be annotated, otherwise, module id is not matched for recorded arguments
+            #@torch.jit.script_method
             def forward(self, x):
                 residual = x
@@ -497,7 +506,8 @@
                 return out

-        class ResNet(torch.jit.ScriptModule):
+        # NOTE: cannot inherit torch.jit.ScriptModule, otherwise, there would be error: 'RecursiveScriptModule' object has no attribute 'graph'
+        class ResNet(nn.Module):  #torch.jit.ScriptModule
             __constants__ = ['layer1', 'layer2', 'layer3', 'layer4']

             def __init__(self, block, layers, num_classes=1000):
@@ -538,7 +548,8 @@
                 return nn.Sequential(*layers)

-            @torch.jit.script_method
+            # NOTE: jit cannot be annotated, otherwise, module id is not matched for recorded arguments
+            #@torch.jit.script_method
             def forward(self, x):
                 x = self.conv1(x)
                 x = self.bn1(x)
@@ -558,10 +569,11 @@
         resnet18 = ResNet(BasicBlock, [2, 2, 2, 2])
-        self.checkExportImport(torchvision.models.resnet18().eval(), (torch.randn(1, 3, 224, 224),))
+        self.checkExportImport(resnet18, (torch.randn(1, 3, 224, 224),))

-    @unittest.skip('torchvision models are not supported yet')  # FIXME
     def test_alexnet(self):
+        from .inject_nn import inject_pytorch_nn
+        inject_pytorch_nn()
         x = torch.ones(1, 3, 224, 224)
         model = torchvision.models.AlexNet()
         self.checkExportImport(model, (x,))
test/ut/retiarii/test_convert_basic.py
(new file)
import os
import sys
import unittest

import numpy as np

import torch
import torch.nn.functional as F
import torchvision

import nni.retiarii.nn.pytorch as nn
from nni.retiarii import blackbox_module
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.utils import get_records

# following pytorch v1.7.1

class TestConvert(unittest.TestCase):
    @staticmethod
    def _match_state_dict(current_values, expected_format):
        result = {}
        for k, v in expected_format.items():
            for idx, cv in enumerate(current_values):
                if cv.shape == v.shape:
                    result[k] = cv
                    current_values.pop(idx)
                    break
        return result

    def checkExportImport(self, model, input, check_value=True):
        script_module = torch.jit.script(model)
        model_ir = convert_to_graph(script_module, model)
        model_code = model_to_pytorch_script(model_ir)
        print(model_code)

        exec_vars = {}
        exec(model_code + '\n\nconverted_model = _model()', exec_vars)
        converted_model = exec_vars['converted_model']
        converted_state_dict = self._match_state_dict(list(model.state_dict().values()),
                                                      dict(converted_model.state_dict()))
        converted_model.load_state_dict(converted_state_dict)
        with torch.no_grad():
            expected_output = model.eval()(*input)
            converted_output = converted_model.eval()(*input)
        if check_value:
            self.assertEqual(len(converted_output), len(expected_output))
            for a, b in zip(converted_output, expected_output):
                if hasattr(a, 'dtype') and a.dtype == torch.bool:
                    self.assertEqual((a ^ b), False)
                elif isinstance((a - b), int):
                    self.assertEqual((a - b), 0)
                else:
                    self.assertLess((a - b).abs().max().item(), 1E-4)
        return converted_model

    # skip torch.Tensor.new_tensor as it is not supported by jit

    def test_basic_new_full(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                # requires_grad is not supported by jit
                # aten::new_full(Tensor self, int[] size, Scalar fill_value, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor):
                # Keyword argument requires_grad unknown.
                out = x.new_full((3, 4), 3.141592, dtype=torch.float32, device=torch.device('cpu'))
                return out
        self.checkExportImport(SimpleOp(), (torch.ones((2,), dtype=torch.float64), ))

    def test_basic_new_empty(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = x.new_empty((2, 3), dtype=torch.int8, device=torch.device('cpu'))
                return out
        self.checkExportImport(SimpleOp(), (torch.ones(()), ), check_value=False)

    # skip torch.Tensor.new_ones as it is not supported by jit
    # requires_grad=False is not supported by jit

    def test_basic_new_zeros(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = x.new_zeros((2, 3))
                return out
        self.checkExportImport(SimpleOp(), (torch.tensor((), dtype=torch.int32), ))

    def test_basic_is_cuda(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                return torch.tensor([x.is_cuda], dtype=torch.bool, device=torch.device('cpu'))
        self.checkExportImport(SimpleOp(), (torch.tensor((), dtype=torch.int32), ))

    # is_quantized
    # is_meta
    # device
    # grad
    # ndim
    # T
    # real
    # imag

    def test_basic_abs(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out1 = x.abs()
                out11 = x.absolute()
                out2 = torch.abs(x)
                #out3 = x.abs_()
                #out33 = x.absolute_()
                return out1, out11, out2  #, out3, out33
        self.checkExportImport(SimpleOp(), (torch.tensor([-1, -2, 3]), ))

    # TODO: topological sort should be improved
    #def forward(self, x__1):
    #    __Acos2 = x__1.acos()
    #    __Acos_3 = x__1.acos_()
    #    __Acos1 = x__1.acos()
    #    __TupleConstruct4 = (__Acos1,__Acos2,__Acos_3)
    #    return __TupleConstruct4
    def test_basic_acos_asin_atan(self):
        class SimpleOp(nn.Module):
            def forward(self, x, y):
                out1 = x.acos()
                out2 = torch.acos(x)
                # TODO: add back this line
                #out = x.acos_()
                out3 = x.asin()
                out4 = torch.asin(x)
                out5 = x.atan()
                out6 = torch.atan(x)
                out7 = x.atan2(y)
                out8 = torch.atan2(x, y)
                return out1, out2, out3, out4, out5, out6, out7, out8  #, out
        self.checkExportImport(SimpleOp(), (torch.tensor([-1.0, -0.5, 0.2]), torch.tensor([1.0, 0.6, -0.3]), ))

    # arccos is not supported by jit

    def test_basic_add(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                t = torch.tensor([-1.0, -0.5, 0.2])
                out1 = x.add(t)
                out2 = x.add(t, alpha=2)
                #out3 = x.add_(t)
                return out1, out2  #, out3
        self.checkExportImport(SimpleOp(), (torch.tensor([-1.0, -0.5, 0.2]), ))

    def test_basic_addbmm(self):
        class SimpleOp(nn.Module):
            def forward(self, x, y, z, m):
                out1 = x.addbmm(y, z, beta=2, alpha=3)
                out2 = torch.addbmm(x, y, z, beta=2, alpha=3)
                #out3 = x.addbmm_(y, z, beta=2, alpha=3)
                out3 = m.baddbmm(y, z, beta=2, alpha=3)
                out4 = torch.baddbmm(m, y, z, beta=2, alpha=3)
                out5 = torch.bmm(y, z)
                # deterministic is not supported by jit
                return out1, out2, out3, out4, out5
        self.checkExportImport(SimpleOp(), (torch.randn(3, 5), torch.randn(10, 3, 4), torch.randn(10, 4, 5), torch.randn(10, 3, 5), ))

    def test_basic_addcdiv(self):
        class SimpleOp(nn.Module):
            def forward(self, x, y, z):
                out1 = x.addcdiv(y, z, value=2)
                out2 = torch.addcdiv(x, y, z, value=2)
                # addcdiv_
                return out1, out2
        self.checkExportImport(SimpleOp(), (torch.randn(1, 3), torch.randn(3, 1), torch.randn(1, 3), ))

    def test_basic_addcmul(self):
        class SimpleOp(nn.Module):
            def forward(self, x, y, z):
                out1 = x.addcmul(y, z, value=0.1)
                out2 = torch.addcmul(x, y, z, value=0.1)
                # addcmul_
                return out1, out2
        self.checkExportImport(SimpleOp(), (torch.randn(1, 3), torch.randn(3, 1), torch.randn(1, 3), ))

    def test_basic_addmm(self):
        class SimpleOp(nn.Module):
            def forward(self, x, y, z):
                out1 = x.addmm(y, z, beta=0.1, alpha=0.2)
                out2 = torch.addmm(x, y, z, beta=0.1, alpha=0.2)
                # addmm_
                return out1, out2
        self.checkExportImport(SimpleOp(), (torch.randn(2, 3), torch.randn(2, 3), torch.randn(3, 3), ))

    def test_basic_addmv(self):
        class SimpleOp(nn.Module):
            def forward(self, x, y, z):
                out1 = x.addmv(y, z, beta=0.1, alpha=0.2)
                out2 = torch.addmv(x, y, z, beta=0.1, alpha=0.2)
                return out1, out2
        self.checkExportImport(SimpleOp(), (torch.randn(2), torch.randn(2, 3), torch.randn(3), ))

    def test_basic_addr(self):
        class SimpleOp(nn.Module):
            def forward(self, x, y, z):
                out1 = x.addr(y, z, beta=2, alpha=3)
                out2 = torch.addr(x, y, z, beta=2, alpha=3)
                return out1, out2
        self.checkExportImport(SimpleOp(), (torch.zeros(3, 2), torch.arange(1., 4.), torch.arange(1., 3.), ))

    def test_basic_allclose(self):
        class SimpleOp(nn.Module):
            def forward(self, x, y):
                out1 = x.allclose(y, rtol=1e-05, atol=1e-08, equal_nan=False)
                out2 = torch.allclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False)
                return out1, out2
        self.checkExportImport(SimpleOp(), (torch.tensor([10000., 1e-07]), torch.tensor([10000.1, 1e-08]), ))

    def test_basic_angle(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out1 = x.angle()
                out2 = torch.angle(x)
                return out1, out2
        self.checkExportImport(SimpleOp(), (torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]), ))

    # skip apply_(callable) for now

    def test_basic_argmax_argmin(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out1 = x.argmax()
                out2 = torch.argmax(x)
                out3 = x.argmax(dim=1)
                out4 = torch.argmax(x, dim=1)
                out5 = x.argmax(dim=1, keepdim=True)
                o1 = x.argmin()
                o2 = torch.argmin(x)
                o3 = x.argmin(dim=1)
                o4 = x.argmin(dim=1, keepdim=True)
                return out1, out2, out3, out4, out5, o1, o2, o3, o4
        self.checkExportImport(SimpleOp(), (torch.randn(4, 4), ))

    def test_basic_argsort(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out1 = x.argsort()
                out2 = x.argsort(dim=1)
                out3 = x.argsort(dim=1, descending=True)
                out4 = torch.argsort(x, dim=1, descending=True)
                return out1, out2, out3, out4
        self.checkExportImport(SimpleOp(), (torch.randn(4, 4), ))

    # skip backward(gradient=None, retain_graph=None, create_graph=False)

    def test_basic_bernoulli(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                # generator=torch.Generator() is not supported by jit
                out = x.bernoulli()
                return out
        self.checkExportImport(SimpleOp(), (torch.ones(3, 3), ))

    # bfloat16/bool/byte/char is not supported by jit

    def test_basic_bincount(self):
        class SimpleOp(nn.Module):
            def forward(self, x, y):
                out1 = x.bincount()
                out2 = torch.bincount(x)
                out3 = x.bincount(weights=y)
                out4 = x.bincount(weights=y, minlength=2)
                return out1, out2, out3, out4
        self.checkExportImport(SimpleOp(), (torch.randint(0, 8, (5,), dtype=torch.int64), torch.linspace(0, 1, steps=5), ))

    def test_basic_bitwise(self):
        class SimpleOp(nn.Module):
            def forward(self, x, y):
                out1 = x.bitwise_not()
                out2 = x.bitwise_and(y)
                out3 = x.bitwise_or(y)
                out4 = x.bitwise_xor(y)
                return out1, out2, out3, out4
        self.checkExportImport(SimpleOp(), (torch.tensor([-1, -2, 3], dtype=torch.int8), torch.tensor([1, 0, 3], dtype=torch.int8), ))

    # cauchy_ is not supported yet

    def test_ceil(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out1 = x.ceil()
                return out1
        self.checkExportImport(SimpleOp(), (torch.randn(4), ))
test/ut/retiarii/test_convert_models.py
(new file)
test/ut/retiarii/test_convert_operators.py
(new file)
'''
The tests in this file is copied and transformed from
`https://github.com/pytorch/pytorch/blob/master/test/onnx/test_operators.py`
'''
import
os
import
sys
import
unittest
from
typing
import
(
Dict
)
import
numpy
as
np
import
torch
import
torch.nn.functional
as
F
import
torchvision
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
blackbox_module
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.utils
import
get_records
# following pytorch v1.7.1
class
TestOperators
(
unittest
.
TestCase
):
@
staticmethod
def
_match_state_dict
(
current_values
,
expected_format
):
result
=
{}
for
k
,
v
in
expected_format
.
items
():
for
idx
,
cv
in
enumerate
(
current_values
):
if
cv
.
shape
==
v
.
shape
:
result
[
k
]
=
cv
current_values
.
pop
(
idx
)
break
return
result
def
checkExportImport
(
self
,
model
,
input
,
check_value
=
True
):
script_module
=
torch
.
jit
.
script
(
model
)
model_ir
=
convert_to_graph
(
script_module
,
model
)
model_code
=
model_to_pytorch_script
(
model_ir
)
#print(model_code)
exec_vars
=
{}
exec
(
model_code
+
'
\n\n
converted_model = _model()'
,
exec_vars
)
converted_model
=
exec_vars
[
'converted_model'
]
converted_state_dict
=
self
.
_match_state_dict
(
list
(
model
.
state_dict
().
values
()),
dict
(
converted_model
.
state_dict
()))
converted_model
.
load_state_dict
(
converted_state_dict
)
with
torch
.
no_grad
():
expected_output
=
model
.
eval
()(
*
input
)
converted_output
=
converted_model
.
eval
()(
*
input
)
if
check_value
:
try
:
self
.
assertEqual
(
len
(
converted_output
),
len
(
expected_output
))
for
a
,
b
in
zip
(
converted_output
,
expected_output
):
torch
.
eq
(
a
,
b
)
except
:
self
.
assertEqual
(
converted_output
,
expected_output
)
return
converted_model
def
test_basic_basic
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
,
y
):
out
=
-
torch
.
sigmoid
(
torch
.
tanh
(
x
*
(
x
+
y
)))
return
out
x
=
torch
.
tensor
([
0.4
],
requires_grad
=
True
)
y
=
torch
.
tensor
([
0.7
],
requires_grad
=
True
)
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
y
,
))
def
test_basic_view
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
):
out
=
x
.
view
(
1
,
1
)
return
out
x
=
torch
.
tensor
([
0.0
],
requires_grad
=
True
)
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
def
test_basic_index
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
):
out
=
x
[
0
]
return
out
x
=
torch
.
tensor
([[
0.0
]],
requires_grad
=
True
)
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
def
test_basic_type_as
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
):
out
=
x
.
type_as
(
x
)
return
out
x
=
torch
.
tensor
([
0.0
],
requires_grad
=
True
)
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
def
test_basic_addconstant
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
):
out
=
x
+
1
return
out
x
=
torch
.
randn
(
2
,
3
,
requires_grad
=
True
).
double
()
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
def
test_basic_add_broadcast
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
,
y
):
out
=
x
+
y
return
out
x
=
torch
.
randn
(
2
,
3
,
requires_grad
=
True
).
double
()
y
=
torch
.
randn
(
3
,
requires_grad
=
True
).
double
()
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
y
,
))
def
test_basic_add_left_broadcast
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
,
y
):
out
=
x
+
y
return
out
x
=
torch
.
randn
(
3
,
requires_grad
=
True
).
double
()
y
=
torch
.
randn
(
2
,
3
,
requires_grad
=
True
).
double
()
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
y
,
))
def
test_basic_add_size1_broadcast
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
,
y
):
out
=
x
+
y
return
out
x
=
torch
.
randn
(
2
,
3
,
requires_grad
=
True
).
double
()
y
=
torch
.
randn
(
2
,
1
,
requires_grad
=
True
).
double
()
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
y
,
))
def
test_basic_add_size1_right_broadcast
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
,
y
):
out
=
x
+
y
return
out
x
=
torch
.
randn
(
2
,
3
,
requires_grad
=
True
).
double
()
y
=
torch
.
randn
(
3
,
requires_grad
=
True
).
double
()
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
y
,
))
def
test_basic_add_size1_singleton_broadcast
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
,
y
):
out
=
x
+
y
return
out
x
=
torch
.
randn
(
2
,
3
,
requires_grad
=
True
).
double
()
y
=
torch
.
randn
(
1
,
3
,
requires_grad
=
True
).
double
()
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
y
,
))
def
test_basic_rsub
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
):
out
=
1
-
x
return
out
x
=
torch
.
randn
(
2
,
3
,
requires_grad
=
True
).
double
()
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
def
test_basic_transpose
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
):
out
=
x
.
transpose
(
0
,
1
).
transpose
(
1
,
0
)
return
out
x
=
torch
.
tensor
([[
0.0
,
1.0
],
[
2.0
,
3.0
]],
requires_grad
=
True
)
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
def
test_basic_chunk
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
):
out
=
x
.
chunk
(
2
)
return
out
x
=
torch
.
tensor
([
0.0
,
1.0
,
2.0
],
requires_grad
=
True
)
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
def
test_basic_split
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
):
out
=
torch
.
split
(
x
,
2
,
1
)
return
out
x
=
torch
.
tensor
([[
0.0
,
1.0
,
1.0
,
0.0
,
2.0
,
2.0
],
[
2.0
,
3.0
,
3.0
,
2.0
,
1.0
,
1.0
]])
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
def
test_basic_split_with_sizes
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
):
out
=
torch
.
split
(
x
,
[
2
,
1
,
3
],
1
)
return
out
x
=
torch
.
tensor
([[
0.0
,
1.0
,
1.0
,
0.0
,
2.0
,
2.0
],
[
2.0
,
3.0
,
3.0
,
2.0
,
1.0
,
1.0
]])
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
@
unittest
.
skip
(
'cannot be parsed by jit'
)
def
test_basic_concat2
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
inputs
):
out
=
torch
.
cat
(
inputs
,
1
)
return
out
x
=
torch
.
randn
(
2
,
3
)
y
=
torch
.
randn
(
2
,
3
)
self
.
checkExportImport
(
SimpleOp
(),
((
x
,
y
),
))
def
test_basic_addmm
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
,
y
,
z
):
out
=
torch
.
addmm
(
torch
.
addmm
(
z
,
x
,
y
),
x
,
y
)
return
out
m1
=
torch
.
randn
(
2
,
3
,
requires_grad
=
True
)
m2
=
torch
.
randn
(
3
,
4
,
requires_grad
=
True
)
m3
=
torch
.
randn
(
4
,
requires_grad
=
True
)
self
.
checkExportImport
(
SimpleOp
(),
(
m1
,
m2
,
m3
,
))
def
test_basic_permute2
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
):
out
=
x
.
permute
(
0
,
1
,
4
,
2
,
5
,
3
)
return
out
x
=
torch
.
tensor
([[[[[[
0.0
]]]]]],
requires_grad
=
True
)
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
))
def
test_basic_params
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
,
y
):
out
=
-
torch
.
sigmoid
(
torch
.
tanh
(
x
*
(
x
+
y
)))
return
out
x
=
torch
.
tensor
([[
1.0
,
2.0
],
[
3.0
,
4.0
]],
requires_grad
=
True
)
y
=
torch
.
nn
.
Parameter
(
torch
.
tensor
([[
1.0
,
2.0
],
[
3.0
,
4.0
]],
requires_grad
=
True
))
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
y
,
))
def
test_basic_params_onnx_irv4
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
,
y
):
out
=
-
torch
.
sigmoid
(
torch
.
tanh
(
x
*
(
x
+
y
)))
return
out
x
=
torch
.
tensor
([[
1.0
,
2.0
],
[
3.0
,
4.0
]],
requires_grad
=
True
)
y
=
torch
.
nn
.
Parameter
(
torch
.
tensor
([[
1.0
,
2.0
],
[
3.0
,
4.0
]],
requires_grad
=
True
))
self
.
checkExportImport
(
SimpleOp
(),
(
x
,
y
,
))
    def test_basic_clip(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.clamp(x, min=-0.5, max=0.5)
                return out
        x = torch.randn(3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_clip_min(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = x.clamp(min=-0.1)
                return out
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_clip_max(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = x.clamp(max=0.1)
                return out
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    @unittest.skip('cannot be parsed by jit')
    def test_basic_hardtanh(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.nn.Hardtanh(-0.5, 0.5)(x)
                return out
        x = torch.randn(3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_full(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.full(x.shape, 2., dtype=torch.float32, layout=torch.strided, device=torch.device('cpu'))
                return out
        x = torch.randn(3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_full_like(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.full_like(x, 2, memory_format=torch.preserve_format)
                return out
        x = torch.randn(3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_max(self):
        class SimpleOp(nn.Module):
            def forward(self, x, y):
                out = torch.max(x, y)
                return out
        x = torch.randn(3, 4, requires_grad=True)
        y = torch.randn(3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, y, ))

    def test_basic_min(self):
        class SimpleOp(nn.Module):
            def forward(self, x, y):
                out = torch.min(x, y)
                return out
        x = torch.randn(3, 4, requires_grad=True)
        y = torch.randn(3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, y, ))
    def test_basic_mean(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.mean(x)
                return out
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_reduced_mean(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.mean(x, dim=2)
                return out
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_reduced_mean_keepdim(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.mean(x, dim=(2, 3), keepdim=True)
                return out
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_sum(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.sum(x)
                return out
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_reduced_sum(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.sum(x, dim=(1, 2))
                return out
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_reduced_sum_keepdim(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.sum(x, dim=2, keepdim=True)
                return out
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_prod(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.prod(x)
                return out
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_reduced_prod(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.prod(x, dim=2)
                return out
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_reduced_prod_keepdim(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.prod(x, dim=2, keepdim=True)
                return out
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))
    def test_basic_sqrt(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.sqrt(x)
                return out
        x = torch.randn(3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_rsqrt(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.rsqrt(x)
                return out
        x = torch.randn(3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_equal(self):
        class SimpleOp(nn.Module):
            def forward(self, x, y):
                out = x == y
                return out
        x = torch.randn(1, 2, 3, 1, requires_grad=False).int()
        y = torch.randn(1, 4, requires_grad=False).int()
        self.checkExportImport(SimpleOp(), (x, y, ))

    def test_basic_lt(self):
        class SimpleOp(nn.Module):
            def forward(self, x, y):
                out = x < y
                return out
        x = torch.randn(1, 2, 3, 1, requires_grad=False).int()
        y = torch.randn(1, 4, requires_grad=False).int()
        self.checkExportImport(SimpleOp(), (x, y, ))

    def test_basic_gt(self):
        class SimpleOp(nn.Module):
            def forward(self, x, y):
                out = x > y
                return out
        x = torch.randn(1, 2, 3, 1, requires_grad=False).int()
        y = torch.randn(1, 4, requires_grad=False).int()
        self.checkExportImport(SimpleOp(), (x, y, ))

    def test_basic_le(self):
        class SimpleOp(nn.Module):
            def forward(self, x, y):
                out = x <= y
                return out
        x = torch.randn(3, 4, requires_grad=False).int()
        y = torch.randn(3, 4, requires_grad=False).int()
        self.checkExportImport(SimpleOp(), (x, y, ))

    def test_basic_ge(self):
        class SimpleOp(nn.Module):
            def forward(self, x, y):
                out = x >= y
                return out
        x = torch.randn(3, 4, requires_grad=False).int()
        y = torch.randn(3, 4, requires_grad=False).int()
        self.checkExportImport(SimpleOp(), (x, y, ))

    def test_basic_exp(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = x.exp()
                return out
        x = torch.randn(3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_sin(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = x.sin()
                return out
        x = torch.randn(3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_cos(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = x.cos()
                return out
        x = torch.randn(3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_tan(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = x.tan()
                return out
        x = torch.randn(3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_asin(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = x.asin()
                return out
        x = torch.rand(3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_acos(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = x.acos()
                return out
        x = torch.rand(3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))
    def test_basic_slice(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = x[:, 1:2]
                return out
        x = torch.rand(3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_slice_dynamic(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = x[x.size(0):, x.size(1) - 3]
                return out
        x = torch.rand(3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_sign(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = x.sign()
                return out
        x = torch.rand(3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_narrow(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.narrow(x, 0, 0, 2)
                return out
        x = torch.randn(3, 3, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_atan(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = x.atan()
                return out
        x = torch.randn(3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_view_flatten(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = x.view(x.size()[0], x.numel() // x.size()[0])
                return out
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_flatten(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.flatten(x)
                return out
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_flatten2D(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.flatten(x, 1)
                return out
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_isnan(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.isnan(x)
                return out
        x = torch.tensor([1, float('nan'), 2])
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_argmax(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.argmax(x, dim=1)
                return out
        x = torch.randn(4, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))
    def test_basic_pow(self):
        class SimpleOp(nn.Module):
            def forward(self, x, y):
                out = x.pow(y)
                return out
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        y = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, y, ))

    def test_basic_repeat(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = x.repeat(1, 2, 3, 4)
                return out
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_repeat_dim_overflow(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = x.repeat(1, 2, 3, 4)
                return out
        x = torch.randn(1, 2, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_norm_p1(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = x.norm(p=1, dim=2)
                return out
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_norm_p2(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = x.norm(p=2, dim=2)
                return out
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_upsample_nearest_size(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.nn.functional.interpolate(x, size=16, mode='nearest')
                return out
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_unsqueeze(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = x.unsqueeze(len(x.shape))
                return out
        x = torch.randn(3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_implicit_expand(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = x + 1
                return out
        x = torch.randn(3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_reduce_sum_negative_indices(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = x.sum(-1)
                return out
        x = torch.randn(3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_randn(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.randn(1, 2, 3, 4) + x
                return out
        x = torch.randn(1, 2, 3, 4)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_rand(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.rand(1, 2, 3, 4) + x
                return out
        x = torch.rand(1, 2, 3, 4)
        self.checkExportImport(SimpleOp(), (x, ))
    def test_basic_empty_like(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.empty_like(x)
                return out
        x = torch.randn(5, 8, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_empty_like_opset7(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.empty_like(x)
                return out
        x = torch.randn(5, 8, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_zeros_like(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.zeros_like(x)
                return out
        x = torch.randn(5, 8, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_ones_like(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.ones_like(x)
                return out
        x = torch.randn(6, 10, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_expand(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = x.expand(4, 6, 2)
                return out
        x = torch.randn(6, 1, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_ne(self):
        class SimpleOp(nn.Module):
            def forward(self, x, y):
                out = torch.ne(x, y)
                return out
        x = torch.randn(1, 2, 3, 1, requires_grad=False).int()
        y = torch.randn(1, 4, requires_grad=False).int()
        self.checkExportImport(SimpleOp(), (x, y, ))

    def test_basic_reducemax(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.max(x)
                return out
        x = torch.randn(1, 2, 3, 4)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_reducemin(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.min(x)
                return out
        x = torch.randn(1, 2, 3, 4)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_erf(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = x.erf()
                return out
        x = torch.randn(1, 2, 3, 4)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_dropout(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.max(torch.nn.functional.dropout(x, training=False))
                return out
        x = torch.randn(3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_dropout_default(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.max(torch.nn.functional.dropout(x,))
                return out
        x = torch.randn(3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ), check_value=False)

    def test_basic_dropout_training(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.max(torch.nn.functional.dropout(x))
                return out
        x = torch.randn(3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ), check_value=False)
    def test_basic_nonzero(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.nonzero(x)
                return out
        x = torch.tensor([[[2., 2.], [1., 0.]], [[0., 0.], [1., 1.]]], requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_gather(self):
        class SimpleOp(nn.Module):
            def forward(self, data, index):
                out = data.gather(1, index)
                return out
        data = torch.randn(3, 4, 3, requires_grad=True)
        index = torch.tensor([2, 0]).view(1, 2, 1).expand(3, 2, 3)
        self.checkExportImport(SimpleOp(), (data, index, ))

    def test_basic_gather_opset11(self):
        class SimpleOp(nn.Module):
            def forward(self, data, index):
                out = data.gather(1, index)
                return out
        data = torch.randn(3, 4, 3, requires_grad=True)
        index = torch.tensor([2, 0]).view(1, 2, 1).expand(3, 2, 3)
        self.checkExportImport(SimpleOp(), (data, index, ))

    def test_basic_scatter_add(self):
        class SimpleOp(nn.Module):
            def forward(self, data, indices, values):
                out = data.scatter_add(1, indices, values)
                return out
        data = torch.tensor([[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]])
        indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
        values = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])
        self.checkExportImport(SimpleOp(), (data, indices, values, ))

    def test_basic_scatter_add_opset11(self):
        class SimpleOp(nn.Module):
            def forward(self, data, indices, values):
                out = data.scatter_add(1, indices, values)
                return out
        data = torch.tensor([[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]])
        indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64)
        values = torch.tensor([[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])
        self.checkExportImport(SimpleOp(), (data, indices, values, ))

    def test_basic_master_opset(self):
        class SimpleOp(nn.Module):
            def forward(self, x, y):
                out = x + y
                return out
        x = torch.randn(2, 3).float()
        y = torch.randn(2, 3).float()
        self.checkExportImport(SimpleOp(), (x, y, ))

    def test_basic_std(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.std(x, dim=(0, 1), unbiased=True, keepdim=True)
                return out
        x = torch.randn(2, 3, 4).float()
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_cumsum(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.cumsum(x, dim=1)
                return out
        x = torch.randn(2, 3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_pixel_shuffle(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.pixel_shuffle(x, upscale_factor=2)
                return out
        x = torch.randn(2, 8, 3, 4).float()
        self.checkExportImport(SimpleOp(), (x, ))
    @unittest.skip('skip as torch.norm is called with prim::CallFunction, also torch.norm is deprecated')
    def test_basic_frobenius_norm(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.norm(x, p="fro", dim=(0, 1), keepdim=True)
                return out
        x = torch.randn(2, 3, 4).float()
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_unfold(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = x.unfold(dimension=2, size=2, step=2)
                return out
        x = torch.randn(2, 3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_remainder(self):
        class SimpleOp(nn.Module):
            def forward(self, x, y):
                out = torch.remainder(x, y)
                return out
        x = torch.randn(2, 3, 4)
        y = torch.randn(2, 1, 4)
        self.checkExportImport(SimpleOp(), (x, y, ))

    def test_basic_fmod(self):
        class SimpleOp(nn.Module):
            def forward(self, x, y):
                out = torch.fmod(x, y)
                return out
        x = torch.randn(2, 3, 4)
        y = torch.randn(2, 1, 4)
        self.checkExportImport(SimpleOp(), (x, y, ))

    def test_basic_gelu(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.nn.functional.gelu(x)
                return out
        x = torch.randn(2, 3, 4, 5, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    @unittest.skip('skip as it is called with prim::CallFunction, and unknown func definition')
    def test_basic_unique(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.unique(x, dim=0, sorted=True, return_inverse=False, return_counts=True)
                return out
        x = torch.randint(3, (2, 3, 4, 5)).float()
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_meshgrid(self):
        class SimpleOp(nn.Module):
            def forward(self, x, y, z):
                out = torch.meshgrid(x, y, z)
                return out
        x = torch.ones(3, requires_grad=True)
        y = torch.zeros(4, requires_grad=True)
        z = torch.ones(5, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, y, z, ))

    def test_basic_topk(self):
        class SimpleOp(nn.Module):
            def forward(self, x, k):
                out = torch.topk(x, k)
                return out
        x = torch.arange(1., 6., requires_grad=True)
        k = torch.tensor(3)
        self.checkExportImport(SimpleOp(), (x, k, ))

    def test_basic_topk_smallest_unsorted(self):
        class SimpleOp(nn.Module):
            def forward(self, x, k):
                out = torch.topk(x, k, largest=False, sorted=False)
                return out
        x = torch.arange(1., 6., requires_grad=True)
        k = torch.tensor(3)
        self.checkExportImport(SimpleOp(), (x, k, ))

    def test_basic_baddbmm(self):
        class SimpleOp(nn.Module):
            def forward(self, x, b1, b2):
                out = torch.baddbmm(x, b1, b2)
                return out
        x = torch.randn(10, 3, 5)
        b1 = torch.randn(10, 3, 4)
        b2 = torch.randn(10, 4, 5)
        self.checkExportImport(SimpleOp(), (x, b1, b2, ))

    def test_basic_round(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.round(x)
                return out
        x = torch.tensor([0.9920, -1.0362, -1.5000, 2.5000], requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_dim(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.scalar_tensor(x.dim())
                return out
        x = torch.ones((2, 2), requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_det(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.det(x)
                return out
        x = torch.randn(2, 3, 5, 5, device=torch.device('cpu'))
        self.checkExportImport(SimpleOp(), (x, ))
    # the followings are more complex tests

    def test_mm(self):
        class SimpleOp(nn.Module):
            def forward(self, x, y):
                out = torch.mm(x, y)
                return out
        m1 = torch.randn(2, 3, requires_grad=True)
        m2 = torch.randn(3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (m1, m2))

    def test_basic_pad(self):
        class SimpleOp(nn.Module):
            def __init__(self):
                super().__init__()
                self.m = nn.ReflectionPad2d((2, 3, 0, 1))
            def forward(self, x):
                out = self.m(x)
                return out
        x = torch.tensor([[[[0.0, 1.0, 1.0, 1.0], [2.0, 3.0, 7.0, 7.0]]]], requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_batchnorm(self):
        class SimpleOp(nn.Module):
            def __init__(self):
                super().__init__()
                self.m = nn.BatchNorm2d(2)
            def forward(self, x):
                out = self.m(x)
                return out
        x = torch.ones(2, 2, 2, 2, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_batchnorm_1d(self):
        class SimpleOp(nn.Module):
            def __init__(self):
                super().__init__()
                self.m = nn.BatchNorm1d(2)
            def forward(self, x):
                out = self.m(x)
                return out
        x = torch.ones(2, 2, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_conv(self):
        class SimpleOp(nn.Module):
            def __init__(self):
                super().__init__()
                self.m = nn.Conv2d(16, 13, 3, bias=False)
            def forward(self, x):
                out = self.m(x)
                return out
        x = torch.ones(20, 16, 50, 40, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_conv_onnx_irv4_opset8(self):
        # This test point checks that for opset 8 (or lower), even if
        # keep_initializers_as_inputs is set to False, it is ignored,
        # and initializers are listed as ONNX graph input, in accordance
        # with ONNX IR v3 semantics (which apply to opset version <= 8).
        class SimpleOp(nn.Module):
            def __init__(self):
                super().__init__()
                self.m = nn.Conv2d(2, 4, 3, bias=False)
                self.m.weight.data.fill_(1.0)
            def forward(self, x):
                out = self.m(x)
                return out
        x = torch.ones(1, 2, 5, 7, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_convtranspose(self):
        class SimpleOp(nn.Module):
            def __init__(self):
                super().__init__()
                self.m = nn.ConvTranspose2d(3, 3, 3, stride=3, bias=False, padding=1, output_padding=2)
            def forward(self, x):
                out = self.m(x)
                return out
        x = torch.ones(2, 3, 4, 5, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x,))

    def test_basic_maxpool(self):
        class SimpleOp(nn.Module):
            def __init__(self):
                super().__init__()
                self.m = nn.MaxPool1d(3, stride=2)
            def forward(self, x):
                out = self.m(x)
                return out
        x = torch.randn(20, 16, 50)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_maxpool_dilations(self):
        class SimpleOp(nn.Module):
            def __init__(self):
                super().__init__()
                self.m = nn.MaxPool1d(2, stride=1, dilation=2)
            def forward(self, x):
                out = self.m(x)
                return out
        x = torch.randn(20, 16, 50)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_avg_pool2d(self):
        class SimpleOp(nn.Module):
            def __init__(self):
                super().__init__()
                self.m = nn.AvgPool2d(3, stride=2)
            def forward(self, x):
                out = self.m(x)
                return out
        x = torch.randn(20, 16, 50, 32)
        self.checkExportImport(SimpleOp(), (x, ))

    @unittest.skip('jit error: "Return value was annotated as having type Tensor but is actually of type Tuple[Tensor, Tensor]"')
    def test_basic_maxpool_indices(self):
        class SimpleOp(nn.Module):
            def __init__(self):
                super().__init__()
                self.m = nn.MaxPool1d(3, stride=2, return_indices=True)
            def forward(self, x):
                out = self.m(x)
                return out
        x = torch.randn(20, 16, 50)
        self.checkExportImport(SimpleOp(), (x, ))
    @unittest.skip("jit error: Tried to access nonexistent attribute or method 'at' of type '__torch__.test_convert_operators.MyFun'")
    def test_at_op(self):
        from torch.autograd import Function
        x = torch.randn(3, 4)
        class MyFun(Function):
            @staticmethod
            def symbolic(g, x):
                return g.at("add", x, x)
            @staticmethod
            def forward(ctx, x):
                return x + x
        class MyModule(nn.Module):
            def forward(self, x):
                return MyFun.apply(x)
        self.checkExportImport(MyModule(), x)

    def test_basic_logsoftmax(self):
        class SimpleOp(nn.Module):
            def __init__(self):
                super().__init__()
                self.m = nn.LogSoftmax(dim=3)
            def forward(self, x):
                out = self.m(x)
                return out
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_elu(self):
        class SimpleOp(nn.Module):
            def __init__(self):
                super().__init__()
                self.m = nn.ELU()
            def forward(self, x):
                out = self.m(x)
                return out
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_selu(self):
        class SimpleOp(nn.Module):
            def __init__(self):
                super().__init__()
                self.m = nn.SELU()
            def forward(self, x):
                out = self.m(x)
                return out
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_upsample_nearest_scale(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.nn.functional.interpolate(x, scale_factor=2., mode='nearest', recompute_scale_factor=False)
                return out
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_upsample_nearest_scale_default_scale_factor(self):
        class SimpleOp(nn.Module):
            def forward(self, x):
                out = torch.nn.functional.interpolate(x, scale_factor=2., mode='nearest')
                return out
        x = torch.randn(1, 2, 3, 4, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_batchnorm_noaffine(self):
        class SimpleOp(nn.Module):
            def __init__(self):
                super().__init__()
                self.m = nn.BatchNorm2d(128, affine=False, momentum=0.3)
            def forward(self, x):
                out = self.m(x)
                return out
        x = torch.randn(128, 128, 1, 1, requires_grad=True)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_embedding_bags(self):
        class SimpleOp(nn.Module):
            def __init__(self):
                super().__init__()
                self.m = nn.EmbeddingBag(10, 8)
            def forward(self, x, y):
                out = self.m(x, y)
                return out
        input = torch.tensor([1, 2, 3, 4]).long()
        offset = torch.tensor([0]).long()
        self.checkExportImport(SimpleOp(), (input, offset, ))
    def test_basic_rrelu(self):
        class SimpleOp(nn.Module):
            def __init__(self):
                super().__init__()
                self.m = nn.RReLU()
            def forward(self, x):
                out = self.m(x)
                return out
        x = torch.randn(1, 2, 3, 4)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_prelu(self):
        class SimpleOp(nn.Module):
            def __init__(self):
                super().__init__()
                self.m = nn.PReLU(2)
            def forward(self, x):
                out = self.m(x)
                return out
        x = torch.randn(1, 2, 3, 4)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_log_sigmoid(self):
        class SimpleOp(nn.Module):
            def __init__(self):
                super().__init__()
                self.m = nn.LogSigmoid()
            def forward(self, x):
                out = self.m(x)
                return out
        x = torch.randn(1, 2, 3, 4)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_basic_linear(self):
        class SimpleOp(nn.Module):
            def __init__(self):
                super().__init__()
                self.m = nn.Linear(4, 5, bias=True)
            def forward(self, x):
                out = self.m(x)
                return out
        x = torch.randn(3, 4)
        self.checkExportImport(SimpleOp(), (x, ))

    def test_retain_param_name_disabled(self):
        class MyModule(nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()
                self.fc1 = nn.Linear(4, 5, bias=False)
                self.fc1.weight.data.fill_(2.)
                self.fc2 = nn.Linear(5, 6, bias=False)
                self.fc2.weight.data.fill_(3.)
            def forward(self, x):
                return self.fc2(self.fc1(x))
        x = torch.randn(3, 4).float()
        self.checkExportImport(MyModule(), (x, ))

    @unittest.skip('Segmentation fault')
    def test_dict(self):
        class MyModel(nn.Module):
            def forward(self, x_in: Dict):
                x_out = {}
                x_out["test_key_out"] = torch.add(x_in[list(x_in.keys())[0]], list(x_in.keys())[0])
                return x_out
        x = {torch.tensor(1.): torch.randn(1, 2, 3)}
        self.checkExportImport(MyModel(), (x, ))

    def test_arange_dynamic(self):
        class TestModel(nn.Module):
            def forward(self, input):
                out = torch.arange(input.shape[0], input.shape[0] + 5, 0.5)
                return out
        input = torch.randn(5, 3, 2)
        self.checkExportImport(TestModel(), (input, ))

    def test_bitshift(self):
        class BitshiftModel(nn.Module):
            def forward(self, input, input2):
                return input >> 1, input2 >> 2
        input = torch.arange(24, dtype=torch.float32).reshape(3, 4, 2)
        input2 = torch.arange(24, dtype=torch.uint8).reshape(3, 4, 2)
        self.checkExportImport(BitshiftModel(), (input, input2, ))

    def test_layer_norm_aten(self):
        class SimpleOp(nn.Module):
            def __init__(self):
                super().__init__()
                self.m = nn.LayerNorm([10, 10])
            def forward(self, x):
                out = self.m(x)
                return out
        x = torch.randn(20, 5, 10, 10)
        self.checkExportImport(SimpleOp(), (x, ))
\ No newline at end of file
test/ut/retiarii/test_convert_pytorch.py
0 → 100644
View file @
58d5c2fa
'''
The tests in this file is copied and transformed from
https://github.com/pytorch/pytorch/blob/master/test/onnx/test_pytorch_onnx_onnxruntime.py
'''
import os
import sys
import unittest
from typing import (Dict)

import numpy as np
import torch
import torch.nn.functional as F
import torchvision

import nni.retiarii.nn.pytorch as nn
from nni.retiarii import blackbox_module
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.utils import get_records


class TestPytorch(unittest.TestCase):
    @staticmethod
    def _match_state_dict(current_values, expected_format):
        result = {}
        for k, v in expected_format.items():
            for idx, cv in enumerate(current_values):
                if cv.shape == v.shape:
                    result[k] = cv
                    current_values.pop(idx)
                    break
        return result

    def run_test(self, model, input, check_value=True):
        script_module = torch.jit.script(model)
        model_ir = convert_to_graph(script_module, model)
        model_code = model_to_pytorch_script(model_ir)
        print(model_code)

        from .inject_nn import remove_inject_pytorch_nn
        remove_inject_pytorch_nn()

        exec_vars = {}
        exec(model_code + '\n\nconverted_model = _model()', exec_vars)
        converted_model = exec_vars['converted_model']
        converted_state_dict = self._match_state_dict(list(model.state_dict().values()),
                                                      dict(converted_model.state_dict()))
        converted_model.load_state_dict(converted_state_dict)
        with torch.no_grad():
            expected_output = model.eval()(*input)
            converted_output = converted_model.eval()(*input)
        if check_value:
            try:
                self.assertEqual(len(converted_output), len(expected_output))
                for a, b in zip(converted_output, expected_output):
                    torch.eq(a, b)
            except:
                self.assertEqual(converted_output, expected_output)
        return converted_model
    def test_embedding_model_with_external_data(self):
        class LargeModel(nn.Module):
            def __init__(self):
                super(LargeModel, self).__init__()
                dim = 15
                n = 4 * 100
                self.emb = nn.Embedding(n, dim)
                self.lin1 = nn.Linear(dim, 1)
                self.seq = nn.Sequential(
                    self.emb,
                    self.lin1,
                )
            def forward(self, input):
                return self.seq(input)
        model = LargeModel()
        x = torch.tensor([2], dtype=torch.long)
        self.run_test(model, (x, ))

    @unittest.skip('skip for now, as it needs inject_nn')
    def test_mobilenet_v2_with_external_data(self):
        model = torchvision.models.mobilenet_v2(pretrained=True)
        x = torch.randn(2, 3, 224, 224, requires_grad=True)
        # We are turning off Onnx Runtime optimization off in this test,
        # because external data format is not supported to in ORT optimizer.
        # Once that support is added, we can set ort_optim_on=True (default).
        self.run_test(model, (x, ))

    def test_attribute_with_external_data(self):
        class LargeModel(nn.Module):
            def forward(self, x):
                return x + torch.ones(2, 1024)
        x = torch.randn(2, 1)
        self.run_test(LargeModel(), (x, ))

    @unittest.skip('skip as it has loop')
    def test_subgraph_with_external_data(self):
        class LargeModel(nn.Module):
            def forward(self, x):
                for i in range(x.size(0)):
                    x = x + torch.ones(2, 1024)
                return x
        x = torch.randn(2, 1)
        self.run_test((LargeModel()), (x, ))

    def test_fuse_conv_bn1d(self):
        class Fuse(nn.Module):
            def __init__(self):
                super(Fuse, self).__init__()
                self.conv = nn.Conv1d(16, 33, 3, stride=2)
                self.bn = nn.BatchNorm1d(33)
            def forward(self, x):
                out = self.conv(x)
                return self.bn(out)
        model = Fuse()
        x = torch.randn(20, 16, 50, requires_grad=True)
        self.run_test(model, (x,))

    def test_fuse_conv_bn2d(self):
        class Fuse(nn.Module):
            def __init__(self):
                super(Fuse, self).__init__()
                self.conv = nn.Conv2d(3, 2, kernel_size=1, stride=2, padding=3, bias=False)
                self.bn = nn.BatchNorm2d(2)
            def forward(self, x):
                out = self.conv(x)
                return self.bn(out)
        model = Fuse()
        x = torch.randn(2, 3, 2, 2, requires_grad=True)
        self.run_test(model, (x,))

    def test_fuse_conv_bn3d(self):
        class Fuse(nn.Module):
            def __init__(self):
                super(Fuse, self).__init__()
                self.conv = nn.Conv3d(3, 2, (3, 5, 2), stride=(2, 1, 1), padding=(3, 2, 0), bias=False)
                self.bn = nn.BatchNorm3d(2)
            def forward(self, x):
                out = self.conv(x)
                return self.bn(out)
        model = Fuse()
        x = torch.randn(2, 3, 10, 50, 100, requires_grad=True)
        self.run_test(model, (x,))

    @unittest.skip('have not supported register_buffer yet')
    def test_reshape_constant_fold(self):
        class Reshape(nn.Module):
            def __init__(self, ):
                super(Reshape, self).__init__()
                self.register_buffer("weight", torch.ones(5))
            def forward(self, x):
                scale_1 = self.weight.reshape(1, -1, 1, 1)
                return x * scale_1
        x = torch.randn(4, 5)
        self.run_test(Reshape(), (x,))

    def run_word_language_model(self, model_name):
        ntokens = 50
        emsize = 5
        nhid = 5
        nlayers = 5
        dropout = 0.2
        tied = False
        batchsize = 5
        model = word_language_model.RNNModel(model_name, ntokens, emsize,
                                             nhid, nlayers, dropout, tied, batchsize)
        x = torch.arange(0, ntokens).long().view(-1, batchsize)
        # Only support CPU version, since tracer is not working in GPU RNN.
        self.run_test(model, (x, model.hidden))
    def get_image_from_url(self, url, size=(300, 200)):
        import os
        from urllib.parse import urlsplit
        from urllib import request
        from PIL import Image
        from torchvision import transforms
        from torch._utils_internal import get_writable_path

        filename = os.path.basename(urlsplit(url)[2])
        data_dir = get_writable_path(os.path.join(os.path.dirname(__file__)))
        path = os.path.join(data_dir, filename)
        data = request.urlopen(url, timeout=15).read()
        with open(path, 'wb') as f:
            f.write(data)
        image = Image.open(path).convert("RGB")
        image = image.resize(size, Image.BILINEAR)
        to_tensor = transforms.ToTensor()
        return to_tensor(image)

    def get_test_images(self):
        image_url = "http://farm3.staticflickr.com/2469/3915380994_2e611b1779_z.jpg"
        image = self.get_image_from_url(url=image_url, size=(100, 320))
        image_url2 = "https://pytorch.org/tutorials/_static/img/tv_tutorial/tv_image05.png"
        image2 = self.get_image_from_url(url=image_url2, size=(250, 380))
        return [image], [image2]

    @unittest.skip('does not support `if A and/or B`')
    def test_faster_rcnn(self):
        from .inject_nn import inject_pytorch_nn
        inject_pytorch_nn()
        model = torchvision.models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300)
        model.eval()
        x = torch.randn(2, 3, 200, 300, requires_grad=True)
        self.run_test(model, (x,))
        dummy_image = [torch.ones(3, 100, 100) * 0.3]
        images, test_images = self.get_test_images()
        self.run_test(model, (images,))
        self.run_test(model, (dummy_image,))

    @unittest.skip('does not support `if A and/or B`')
    def test_mask_rcnn(self):
        from .inject_nn import inject_pytorch_nn
        inject_pytorch_nn()
        model = torchvision.models.detection.mask_rcnn.maskrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300)
        images, test_images = self.get_test_images()
        self.run_test(model, (images,))
        dummy_image = [torch.ones(3, 100, 100) * 0.3]
        self.run_test(model, (dummy_image,))

    @unittest.skip('does not support `if A and/or B`')
    def test_keypoint_rcnn(self):
        from .inject_nn import inject_pytorch_nn
        inject_pytorch_nn()
        model = torchvision.models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300)
        images, test_images = self.get_test_images()
        self.run_test(model, (images,))
        dummy_images = [torch.ones(3, 100, 100) * 0.3]
        self.run_test(model, (dummy_images,))

    def test_shufflenet_v2_dynamic_axes(self):
        from .inject_nn import inject_pytorch_nn
        inject_pytorch_nn()
        model = torchvision.models.shufflenet_v2_x0_5(pretrained=True)
        dummy_input = torch.randn(1, 3, 224, 224, requires_grad=True)
        test_inputs = torch.randn(3, 3, 224, 224, requires_grad=True)
        self.run_test(model, (dummy_input,))

    @unittest.skip('')
    def test_word_language_model_RNN_TANH(self):
        self.run_word_language_model("RNN_TANH")

    @unittest.skip('')
    def test_word_language_model_RNN_RELU(self):
        self.run_word_language_model("RNN_RELU")

    @unittest.skip('')
    def test_word_language_model_LSTM(self):
        self.run_word_language_model("LSTM")

    @unittest.skip('')
    def test_word_language_model_GRU(self):
        self.run_word_language_model("GRU")
    def test_index_1d(self):
        class MyModel(nn.Module):
            def forward(self, input):
                return input[0]
        m1 = torch.randn(3, 4, 5, 6, 7)
        self.run_test(MyModel(), (m1, ))

    def test_index_2d_1dimslice(self):
        class MyModel(nn.Module):
            def forward(self, input):
                return input[0:1, :]
        m1 = torch.randn(3, 4, 5, 6, 7)
        self.run_test(MyModel(), (m1, ))

    def test_index_2d_sliceint(self):
        class MyModel(nn.Module):
            def forward(self, input):
                return input[1, :]
        m1 = torch.randn(3, 4, 5, 6, 7)
        self.run_test(MyModel(), (m1, ))

    def test_index_2d_neg_slice(self):
        class MyModel(nn.Module):
            def forward(self, input):
                return input[0:-1, :]
        m1 = torch.randn(3, 4, 5, 6, 7)
        self.run_test(MyModel(), (m1, ))

    def test_index_mask(self):
        class MyModel(nn.Module):
            def forward(self, input):
                return input[torch.tensor([0, 1, 0], dtype=torch.uint8)]
        m1 = torch.randn(3, 4, 5, 6, 7)
        self.run_test(MyModel(), (m1, ))

        class MyModel(nn.Module):
            def forward(self, input):
                return input[torch.tensor([0, 1, 0], dtype=torch.bool)]
        m1 = torch.randn(3, 4, 5, 6, 7)
        self.run_test(MyModel(), (m1, ))

    def test_data(self):
        class Data(nn.Module):
            def forward(self, x):
                return x.new_zeros(x.data.size())
        x = torch.randn(3, 4)
        self.run_test(Data(), (x, ))

    def test_index_mask_nd(self):
        class MyModel(nn.Module):
            def forward(self, input):
                return input[input > 0]
        m1 = torch.randn(3, 4, 5, 6, 7)
        self.run_test(MyModel(), (m1, ))

    @unittest.skip("Tried to access nonexistent attribute or method 'keys' of type 'Tensor (inferred)'.")
    def test_dict(self):
        class MyModel(nn.Module):
            def forward(self, x_in):
                x_out = {}
                x_out["test_key_out"] = torch.add(x_in[list(x_in.keys())[0]], list(x_in.keys())[0])
                return x_out
        x = {torch.tensor(1.): torch.randn(1, 2, 3)}
        self.run_test(MyModel(), (x, {}))

    @unittest.skip("Unsupported operation: indexing tensor with unsupported index type 'str'.")
    def test_dict_str(self):
        class MyModel(nn.Module):
            def forward(self, x_in):
                x_out = {}
                x_out["test_key_out"] = torch.add(x_in["test_key_in"], 2.)
                return x_out
        x = {"test_key_in": torch.randn(1, 2, 3)}
        self.run_test(MyModel(), (x, {}))

    @unittest.skip('Convert graph error')
    def test_optional_inputs_with_no_optionals(self):
        class NoOptionalModel(nn.Module):
            def forward(self, input):
                return input
        # Without empty optional arguments dictionary
        x = torch.randn(2, 3)
        self.run_test(NoOptionalModel(), (x,))
        # With empty optional arguments dictionary
        y = torch.randn(2, 3)
        self.run_test(NoOptionalModel(), (y, {}))

    # NOTE: torch script gets an incorrect graph...
    def test_optional_inputs_with_mixed_optionals(self):
        class MixedModel(nn.Module):
            def forward(self, x: 'Tensor', y: 'Tensor', z: 'Tensor'):
                if y is not None:
                    return x + y
                if z is not None:
                    return x + z
                return x
        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        z = torch.randn(2, 3)
        # Without optional arguments dictionary
        self.run_test(MixedModel(), (x, y, None))
        #self.run_test(MixedModel(), (x, None, z, ))
        # With optional arguments dictionary
        #self.run_test(MixedModel(), (x, {'y': y, 'z': None}))
        #self.run_test(MixedModel(), (x, {'y': None, 'z': z}))
        #self.run_test(MixedModel(), (x, {'z': z}))
        #self.run_test(MixedModel(), (x, {'y': y}))

    @unittest.skip('torch script gets an incorrect graph...')
    def test_optional_inputs_with_all_optionals(self):
        class AllOptionalModel(nn.Module):
            def forward(self, y, z):
                if y is not None:
                    return y
                if z is not None:
                    return z
        y = torch.randn(2, 3)
        # Without optional arguments dictionary
        self.run_test(AllOptionalModel(), (y, None))
        # With optional arguments dictionary
        #self.run_test(AllOptionalModel(), {'y': y, 'z': None})

    @unittest.skip('torch script gets an incorrect graph...')
    def test_none_as_input(self):
        class Model(nn.Module):
            def forward(self, x, y):
                if y is not None:
                    return x + y
                return x
        x = torch.randn(2, 3)
        self.run_test(Model(), (x, None))

    @unittest.skip('jit cannot correctly deal with tuple as input argument')
    def test_none_as_tuple_input(self):
        class Model(nn.Module):
            def forward(self, x, y):
                if y[0] is not None:
                    return x + y[0]
                if y[1] is not None:
                    return x + y[1]
                return x
        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        self.run_test(Model(), (x, (None, y)))

    def test_cste_script(self):
        class MyModel(nn.Module):
            def forward(self, x):
                return torch.zeros(x.size(0)), torch.ones((x.size(1), x.size(0)), dtype=torch.int64)
        x = torch.randn(3, 4)
        self.run_test(MyModel(), (x, ))

    def test_scalar_tensor(self):
        class test(nn.Module):
            def forward(self, input):
                return torch.scalar_tensor(input.size(0)), \
                    torch.scalar_tensor(input.size(1), dtype=torch.int64)
        x = torch.randn(2, 3, 4)
        y = torch.randn(7, 8, 9)
        model = test()
        self.run_test(model, (x, ))

    def test_tensor(self):
        class ScalarInputModel(nn.Module):
            def forward(self, input):
                return torch.tensor(input.shape[1])
        x = torch.randn(3, 4)
        self.run_test(ScalarInputModel(), (x, ))

        class TensorInputModel(nn.Module):
            def forward(self, input):
                return torch.tensor([input.shape[0], input.shape[1]])
        x = torch.randn(3, 4)
        self.run_test(TensorInputModel(), (x, ))

        class FloatInputModel(nn.Module):
            def forward(self, input):
                return torch.tensor([float(input)])
        x = torch.randn(1)
        self.run_test(FloatInputModel(), (x, ))

        class InputWithDtypeModel(nn.Module):
            def forward(self, input):
                return torch.tensor(input.shape[1], dtype=torch.long)
        x = torch.randn(3, 4)
        self.run_test(InputWithDtypeModel(), (x, ))

        class MixedInputModel(nn.Module):
            def forward(self, input):
                return torch.tensor([input.shape[0], int(input)])
        x = torch.randn(1)
        self.run_test(MixedInputModel(), (x, ))
    def test_hardtanh(self):
        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.m = nn.Hardtanh(-1.5, 2.5)
            def forward(self, x):
                return self.m(x)
        x = torch.arange(-5, 5).to(dtype=torch.float32)
        self.run_test(MyModel(), (x, ))

    def test_hardtanh_script_with_default_values(self):
        class MyModel(nn.Module):
            def forward(self, x):
                return F.hardtanh(x)
        x = torch.arange(-5, 5).to(dtype=torch.float32)
        self.run_test(MyModel(), (x, ))

    def test_hardswish(self):
        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.m = nn.Hardswish()
            def forward(self, x):
                return self.m(x)
        x = torch.rand(3, 3).to(dtype=torch.float32)
        self.run_test(MyModel(), (x, ))

        # Testing edge cases
        x = torch.tensor(3).to(dtype=torch.float32)
        self.run_test(MyModel(), (x, ))
        x = torch.tensor(-3).to(dtype=torch.float32)
        self.run_test(MyModel(), (x, ))

    def test_hardswish_script(self):
        class MyModel(nn.Module):
            def forward(self, x):
                return F.hardswish(x)
        x = torch.rand(3, 3).to(dtype=torch.float32)
        self.run_test(MyModel(), (x, ))

    def test_clamp(self):
        class ClampModel(nn.Module):
            def forward(self, x):
                return x.clamp(-0.5, 0.5)
        x = torch.randn(3, 4)
        self.run_test(ClampModel(), (x, ))

        class ClampMinModel(nn.Module):
            def forward(self, x):
                return x.clamp(min=-0.5)
        x = torch.randn(3, 4)
        self.run_test(ClampMinModel(), (x, ))

        class ClampMaxModel(nn.Module):
            def forward(self, x):
                return x.clamp(max=0.5)
        x = torch.randn(3, 4)
        self.run_test(ClampMaxModel(), (x, ))

    def test_clamp_dyn(self):
        class ClampMaxModel(nn.Module):
            def forward(self, x):
                return x.clamp(None, x.size(0))
        x = torch.arange(16).view(4, 4).float()
        self.run_test(ClampMaxModel(), (x, ))

        class ClampMinModel(nn.Module):
            def forward(self, x):
                return x.clamp(x.size(0), None)
        x = torch.arange(16).view(4, 4).float()
        self.run_test(ClampMinModel(), (x, ))

        class ClampMinMaxModel(nn.Module):
            def forward(self, x):
                return x.clamp(x.size(0), x.size(1))
        x = torch.arange(16).view(2, 8).float()
        self.run_test(ClampMinMaxModel(), (x, ))
    def test_full_trace(self):
        class FullModel(nn.Module):
            def forward(self, x):
                return torch.full((3, 4), x, dtype=torch.long)
        x = torch.tensor(12)
        self.run_test(FullModel(), (x, ))

    def test_full_script(self):
        class FullModelScripting(nn.Module):
            def forward(self, x):
                return torch.full((3, 4), x, dtype=torch.long)
        x = torch.tensor(12)
        self.run_test(FullModelScripting(), (x, ))

    def test_fuse_addmm(self):
        class AddmmModel(nn.Module):
            def forward(self, x):
                return torch.mm(x, x) + x
        x = torch.ones(3, 3)
        self.run_test(AddmmModel(), (x, ))

    def test_maxpool(self):
        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.m = nn.MaxPool1d(2, stride=1)
            def forward(self, x):
                return self.m(x)
        x = torch.randn(20, 16, 50)
        self.run_test(MyModel(), (x, ))

    def test_conv(self):
        class TraceModel(nn.Module):
            def __init__(self):
                super(TraceModel, self).__init__()
                self.conv1 = nn.Conv1d(16, 33, 3, stride=2)
                self.conv2 = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
                self.conv3 = nn.Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0))
            def forward(self, input1, input2, input3):
                return self.conv1(input1), self.conv2(input2), self.conv3(input3)
        x1 = torch.randn(20, 16, 50)
        x2 = torch.randn(20, 16, 50, 100)
        x3 = torch.randn(20, 16, 10, 50, 100)
        self.run_test(TraceModel(), (x1, x2, x3, ))

    def test_conv_shape_inference(self):
        class Model(nn.Module):
            def __init__(self):
                super(Model, self).__init__()
                self.conv2 = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
            def forward(self, input):
                return self.conv2(input) + 2
        x = torch.randn(20, 16, 50, 100)
        self.run_test(Model(), (x, ))

    def test_conv_transpose(self):
        class TraceModel(nn.Module):
            def __init__(self):
                super(TraceModel, self).__init__()
                self.conv1 = nn.ConvTranspose1d(16, 33, 3, stride=2)
                self.conv2 = nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
                self.conv3 = nn.ConvTranspose3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0))
            def forward(self, input1, input2, input3):
                return self.conv1(input1), self.conv2(input2), self.conv3(input3)
        x1 = torch.randn(20, 16, 50)
        x2 = torch.randn(20, 16, 50, 100)
        x3 = torch.randn(20, 16, 10, 50, 100)
        self.run_test(TraceModel(), (x1, x2, x3, ))

    # Conversion of Transpose depends on input shape to be known.
    # The following test only works when onnx shape inference is enabled.
    def test_transpose_infer_shape(self):
        class TransposeModule(nn.Module):
            def __init__(self):
                super(TransposeModule, self).__init__()
                self.conv = nn.Conv2d(3, 1, 3, stride=2)
            def forward(self, x):
                x = self.conv(x)
                return x.transpose(0, 1)
        x = torch.randn(32, 3, 64, 64)
        y = torch.randn(16, 3, 8, 64)
        self.run_test(TransposeModule(), (x, ))

    def squeeze_model_tests(self, d, x1):
        class Squeeze(nn.Module):
            def __init__(self, d):
                super(Squeeze, self).__init__()
                self.d = d
            def forward(self, x):
                if self.d is not None:
                    return torch.squeeze(x, dim=self.d)
                else:
                    return torch.squeeze(x)
        self.run_test(Squeeze(d), (x1, ))

    def test_squeeze_without_no_op(self):
        x = torch.randn(2, 1, 4)
        self.squeeze_model_tests(1, x)

    def test_squeeze_neg_without_no_op(self):
        x = torch.randn(2, 1, 4)
        self.squeeze_model_tests(-2, x)

    def test_squeeze_all_dims(self):
        x_squeeze = torch.randn(2, 1, 4)
        self.squeeze_model_tests(None, x_squeeze)

    def test_squeeze_no_op(self):
        x_noop = torch.randn(2, 1, 4)
        self.squeeze_model_tests(2, x_noop)

    def test_squeeze_runtime_dim(self):
        class Squeeze(nn.Module):
            def forward(self, d1, d2):
                t = torch.zeros(d1[0], d2[0])
                return t.squeeze(0)
        d1 = torch.tensor([1])
        d3 = torch.tensor([3])
        d4 = torch.tensor([4])
        self.run_test(Squeeze(), (d1, d4))
        self.run_test(Squeeze(), (d3, d4))

    def test_squeeze(self):
        class Squeeze(nn.Module):
            def forward(self, x):
                return torch.squeeze(x, dim=-2)
        x = torch.randn(2, 1, 4)
        self.run_test(Squeeze(), (x, ))

    def test_unsqueeze(self):
        class Unsqueeze(nn.Module):
            def forward(self, x):
                return torch.unsqueeze(x, dim=-2)
        x = torch.randn(2, 3, 4)
        self.run_test(Unsqueeze(), (x, ))
def
test_maxpool_default_stride
(
self
):
class
MaxPoolModel
(
nn
.
Module
):
def
forward
(
self
,
x
):
return
F
.
max_pool2d
(
x
,
2
)
model
=
MaxPoolModel
()
x
=
torch
.
randn
(
10
,
20
,
16
,
50
)
self
.
run_test
(
model
,
(
x
,
))
def
test_maxpool_adaptive
(
self
):
class
MyModel
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
m
=
nn
.
AdaptiveMaxPool1d
((
5
),
return_indices
=
False
)
def
forward
(
self
,
x
):
return
self
.
m
(
x
)
x
=
torch
.
randn
(
20
,
16
,
50
,
requires_grad
=
True
)
self
.
run_test
(
MyModel
(),
(
x
,
))
def
test_maxpool_2d
(
self
):
class
MyModel
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
m
=
nn
.
MaxPool2d
(
5
,
padding
=
(
1
,
2
))
def
forward
(
self
,
x
):
return
self
.
m
(
x
)
x
=
torch
.
randn
(
1
,
20
,
16
,
50
,
requires_grad
=
True
)
self
.
run_test
(
MyModel
(),
(
x
,
))
def
test_maxpool_1d_ceil
(
self
):
class
MyModel
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
m
=
nn
.
MaxPool1d
(
3
,
2
,
ceil_mode
=
True
)
def
forward
(
self
,
x
):
return
self
.
m
(
x
)
x
=
torch
.
randn
(
20
,
16
,
50
)
self
.
run_test
(
MyModel
(),
(
x
,
))
def
test_maxpool_2d_ceil
(
self
):
class
MyModel
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
m
=
nn
.
MaxPool2d
(
3
,
2
,
ceil_mode
=
True
)
def
forward
(
self
,
x
):
return
self
.
m
(
x
)
x
=
torch
.
randn
(
20
,
16
,
50
,
32
)
self
.
run_test
(
MyModel
(),
(
x
,
))
def
test_maxpool_3d_ceil
(
self
):
class
MyModel
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
m
=
nn
.
MaxPool3d
(
3
,
2
,
ceil_mode
=
True
)
def
forward
(
self
,
x
):
return
self
.
m
(
x
)
x
=
torch
.
randn
(
20
,
16
,
50
,
44
,
31
)
self
.
run_test
(
MyModel
(),
(
x
,
))
    @unittest.skip('jit error: Return value was annotated as having type Tensor but is actually of type Tuple[Tensor, Tensor]')
    def test_maxpool_with_indices(self):
        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.m = nn.MaxPool1d(2, stride=1, return_indices=True)

            def forward(self, x):
                return self.m(x)

        x = torch.randn(20, 16, 50)
        self.run_test(MyModel(), (x, ))

    def test_maxpool_dilation(self):
        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.m = nn.MaxPool1d(2, stride=1, dilation=2)

            def forward(self, x):
                return self.m(x)

        x = torch.randn(20, 16, 50)
        self.run_test(MyModel(), (x, ))
    def test_avgpool_default_stride(self):
        class AvgPoolModel(nn.Module):
            def forward(self, x):
                return F.avg_pool2d(x, 2)

        model = AvgPoolModel()
        x = torch.randn(10, 20, 16, 50)
        self.run_test(model, (x, ))

    def test_avgpool(self):
        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.m = nn.AvgPool1d(2, stride=1)

            def forward(self, x):
                return self.m(x)

        x = torch.randn(20, 16, 50)
        self.run_test(MyModel(), (x, ))

    def test_avgpool_1d_ceil(self):
        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.m = nn.AvgPool1d(3, 2, ceil_mode=True)

            def forward(self, x):
                return self.m(x)

        x = torch.randn(1, 1, 7)
        self.run_test(MyModel(), (x, ))

    def test_avgpool_2d_ceil(self):
        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.m = nn.AvgPool2d(3, 2, ceil_mode=True)

            def forward(self, x):
                return self.m(x)

        x = torch.randn(20, 16, 50, 32)
        self.run_test(MyModel(), (x, ))

    def test_avgpool_3d_ceil(self):
        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.m = nn.AvgPool3d(3, 2, ceil_mode=True)

            def forward(self, x):
                return self.m(x)

        x = torch.randn(20, 16, 50, 44, 31)
        self.run_test(MyModel(), (x, ))
    @unittest.skip('Unsupported op type aten::is_floating_point in if condition')
    def test_floating_point(self):
        class FloatingPoint(nn.Module):
            def forward(self, x):
                if x.is_floating_point():
                    return x.new_zeros(x.shape)
                return x.new_zeros(x.shape)

        x = torch.randn(2, 3, 4)
        self.run_test(FloatingPoint(), (x, ))

        class FloatingPoint(nn.Module):
            def forward(self, x):
                if x.size(0) > 1:
                    a = x + 2
                    if a.is_floating_point():
                        return x + 1
                    return x + 1
                return x

        x = torch.randn(2, 3, 4)
        self.run_test(FloatingPoint(), (x, ))

    # Operator rank mismatch between outputs of two branches for opsets below 11.
    @unittest.skip('Unsupported op type aten::size in if condition')
    def test_floating_point_infer_dtype(self):
        class FloatingPoint(nn.Module):
            def forward(self, x):
                if x.size(0) > 1:
                    a = x + 2
                    if a.is_floating_point():
                        return x.new_zeros(x.shape[1:])
                    return x.new_zeros(x.shape)
                return x

        x = torch.randn(2, 3, 4)
        self.run_test(FloatingPoint(), (x, ))

        class FloatingPoint(nn.Module):
            def forward(self, x):
                if x.size(0) > 1:
                    a = x + 2
                    if a.is_floating_point():
                        return x + 1
                    return x
                return x

        x = torch.randn(2, 3, 4).to(torch.int32)
        self.run_test(FloatingPoint(), (x, ))
    def test_arithmetic(self):
        class ArithmeticModule(nn.Module):
            def forward(self, x):
                x = x + 2
                x = x - 4
                x = x * 6
                x = x / 8
                return x

        x = torch.randn(2, 3, 4)
        self.run_test(ArithmeticModule(), (x, ))

    # In scripting, the first transpose node does not carry shape and dtype info.
    # The following test only works when onnx shape inference is enabled.
    def test_arithmetic_infer_dtype(self):
        class ArithmeticModule(nn.Module):
            def forward(self, x):
                x = x.t()
                x = x + 2
                x = x - 4
                x = x * 6
                x = x / 8
                return x

        x = torch.randn(2, 3)
        self.run_test(ArithmeticModule(), (x, ))
    @unittest.skip('tensor op type aten::to has more than one matched')
    def test_floor_div(self):
        class FloorDivModule(nn.Module):
            def forward(self, x, y):
                return x // 3, x // 2., \
                    x.to(dtype=torch.float64) // 3, x.to(dtype=torch.float64) // 2., \
                    x.to(dtype=torch.int64) // 3, x.to(dtype=torch.int64) // 2., \
                    x // (y + 1.).to(dtype=torch.int64), x // y, \
                    x.to(dtype=torch.float64) // y.to(dtype=torch.int64), \
                    x.to(dtype=torch.float64) // y.to(dtype=torch.float64), \
                    x.to(dtype=torch.int64) // y.to(dtype=torch.int64), \
                    x.to(dtype=torch.int64) // y

        x = torch.randn(2, 3, 4)
        y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4)
        self.run_test(FloorDivModule(), (x, y))

    def test_floor_div_script(self):
        class FloorDivModule(nn.Module):
            def forward(self, x, y):
                return x // 3, x // 2., x // y

        x = torch.randn(2, 3, 4)
        y = torch.randn(2, 3, 4)
        self.run_test(FloorDivModule(), (x, y))

    def test_floordiv(self):
        class FloordivModule(nn.Module):
            def forward(self, x):
                return x.new_zeros(x.size(2) // x.size(1))

        x = torch.randn(2, 3, 4)
        self.run_test(FloordivModule(), (x,))

    def test_div(self):
        class DivModule(nn.Module):
            def forward(self, x, y):
                return torch.true_divide(x, y)

        x = torch.randn(2, 3, 4).to(torch.int)
        y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)
        self.run_test(DivModule(), (x, y))
        self.run_test(DivModule(), (x.float(), y.float()))
    # Note: div cannot (generally) be exported via scripting
    # since its type promotion logic is dependent on knowing the scalar types
    # of the input tensors. That is, the ONNX graph is dependent on the
    # data type of the inputs. This makes it appropriate for tracing only.
    def test_div_promotion_trace(self):
        class DivModule(nn.Module):
            def forward(self, x, y):
                return torch.true_divide(x, y)

        x = torch.randn(2, 3, 4).to(torch.int)
        y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)

        prev_default = torch.get_default_dtype()

        torch.set_default_dtype(torch.float)
        self.run_test(DivModule(), (x, y))

        torch.set_default_dtype(torch.double)
        self.run_test(DivModule(), (x, y))

        torch.set_default_dtype(prev_default)
    # In scripting x, y do not carry shape and dtype info.
    # The following test only works when onnx shape inference is enabled.
    def test_div_promotion_script(self):
        class DivModule(nn.Module):
            def forward(self, x, y):
                # Add transpose to hide shape/type information.
                # Otherwise shape and type are still available from the input.
                x = x.transpose(1, 2)
                y = y.transpose(1, 2)
                return torch.true_divide(x, y)

        x = torch.randn(2, 3, 4).to(torch.int)
        y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)

        prev_default = torch.get_default_dtype()

        # 1. x, y are int, and output is float.
        #    This can be handled by the default case, where both are cast to float.
        #    It works even if the types of x and y are unknown.
        torch.set_default_dtype(torch.float)
        self.run_test(DivModule(), (x, y))

        # 2. x, y are int, and output is double.
        #    This can be handled by the default case, where both are cast to double.
        #    It works even if the types of x and y are unknown.
        torch.set_default_dtype(torch.double)
        self.run_test(DivModule(), (x, y))

        # 3. x is int, y is double, and output is double.
        #    This can only be handled when the types of both x and y are known.
        torch.set_default_dtype(prev_default)
        x = torch.randn(2, 3, 4).to(torch.int)
        y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.double)
        self.run_test(DivModule(), (x, y))
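The promotion behaviour these two tests pin down is plain PyTorch: true division of integer tensors yields a tensor of the current default dtype, which is why the traced graph depends on torch.set_default_dtype. A minimal standalone check of that behaviour (independent of the converter, using only public PyTorch calls):

    import torch

    a = torch.arange(6, dtype=torch.int32)
    b = torch.full((6,), 4, dtype=torch.int32)

    torch.set_default_dtype(torch.float)
    print(torch.true_divide(a, b).dtype)   # torch.float32

    torch.set_default_dtype(torch.double)
    print(torch.true_divide(a, b).dtype)   # torch.float64

    torch.set_default_dtype(torch.float)   # restore the usual default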
    def test_slice_trace(self):
        class MyModule(nn.Module):
            def forward(self, x):
                return x[0:1]

        x = torch.randn(3)
        self.run_test(MyModule(), (x, ))

    def test_slice_neg(self):
        class NegSlice(nn.Module):
            def forward(self, x):
                return x[-1:]

        x = torch.randn(3, 4, 5)
        self.run_test(NegSlice(), (x, ))

    def test_slice_neg_large(self):
        class NegSlice(nn.Module):
            def forward(self, x):
                return x[:, :, -3:-1, :, -1]

        x = torch.randn(3, 4, 5, 6, 7)
        self.run_test(NegSlice(), (x, ))

    def test_slice_neg_large_negone(self):
        class NegSlice(nn.Module):
            def forward(self, x):
                return x[:, :, :, :, -1]

        x = torch.randn(3, 4, 5, 6, 7)
        self.run_test(NegSlice(), (x, ))

    @unittest.skip('strange torch script graph')
    def test_slice_with_input_index(self):
        class InputIndexSlice(nn.Module):
            def forward(self, x, y):
                x[:y.size(0), 0, :] = y
                return x

        x = torch.zeros((56, 6, 256))
        y = torch.rand((22, 256))
        self.run_test(InputIndexSlice(), (x, y))

    @unittest.skip('Loop has not been supported yet!')
    def test_slice_dynamic(self):
        class DynamicSliceExportMod(nn.Module):
            def forward(self, x):
                results = []
                for i in range(4):
                    results.append(x[:x.size(0) - i, i:x.size(2), i:3])
                return results

        x = torch.rand(5, 5, 5)
        y = torch.randn(6, 7, 8)
        self.run_test(DynamicSliceExportMod(), (x, ))

    def test_slice_dynamic_script(self):
        class DynamicSliceModel(nn.Module):
            def forward(self, x):
                return x[1:x.size(1)]

        x = torch.rand(1, 2)
        self.run_test(DynamicSliceModel(), (x, ))

    def test_slice_dynamic_shape_script(self):
        class DynamicSliceModel(nn.Module):
            def forward(self, x):
                return x.new_zeros(x.shape[1:x.size(2)])

        x = torch.rand(1, 2, 3, 4)
        self.run_test(DynamicSliceModel(), (x, ))

    @unittest.skip('Loop has not been supported yet!')
    def test_slice_dynamic_to_end(self):
        class DynamicSliceExportMod(nn.Module):
            def forward(self, x):
                results = []
                for i in range(4):
                    results.append(x[:, i:, x.size(2) - 5])
                return results

        x = torch.rand(5, 5, 5)
        self.run_test(DynamicSliceExportMod(), (x, ))
    def test_square(self):
        class Square(nn.Module):
            def forward(self, x):
                return torch.square(x)

        x = torch.randn(2, 3, 4)
        self.run_test(Square(), (x, ))

    def test_arange_dynamic(self):
        class ArangeModel(nn.Module):
            def forward(self, input):
                return torch.arange(input.shape[0]), \
                    torch.arange(12), \
                    torch.arange(start=input.shape[0], end=input.shape[0] + 5)

        x = torch.randn(5, 3, 2)
        y = torch.randn(8, 3, 2)
        self.run_test(ArangeModel(), (x, ))

    @unittest.skip('mismatched aten::arange definition, does not support `out`')
    def test_dynamic_arange_out(self):
        class ArangeOutModel(nn.Module):
            def forward(self, end):
                out_t = torch.tensor([1], dtype=torch.int64)
                return torch.arange(end, out=out_t)

        x = torch.tensor(8)
        self.run_test(ArangeOutModel(), (x, ))

    @unittest.skip('mismatched aten::arange definition, does not support `out`')
    def test_dynamic_arange_start_out(self):
        class ArangeStartOutModel(nn.Module):
            def forward(self, start, end):
                out_t = torch.tensor([1], dtype=torch.int64)
                return torch.arange(start.size(0), end, out=out_t)

        x = torch.randn(2, 3, 4)
        y = torch.tensor(8)
        self.run_test(ArangeStartOutModel(), (x, y))

    def test_arange(self):
        class ArangeModel(nn.Module):
            def forward(self, start, end):
                return torch.arange(start.size(0), end, 1.5, dtype=torch.int64)

        x = torch.randn(2, 3, 4)
        y = torch.tensor(8.5, dtype=torch.float)
        self.run_test(ArangeModel(), (x, y))

    @unittest.skip('mismatched aten::arange definition, does not support `out`')
    def test_arange_out(self):
        class ArangeOutModel(nn.Module):
            def forward(self, end):
                out_t = torch.tensor([1], dtype=torch.float)
                return torch.arange(end, out=out_t)

        x = torch.tensor(8.5, dtype=torch.float)
        self.run_test(ArangeOutModel(), (x, ))

    @unittest.skip('mismatched aten::arange definition, does not support `out`')
    def test_arange_start_out(self):
        class ArangeStartOutModel(nn.Module):
            def forward(self, start, end):
                out_t = torch.tensor([1], dtype=torch.float)
                return torch.arange(start.size(0), end, out=out_t)

        x = torch.randn(2, 3, 4)
        y = torch.tensor(8.5, dtype=torch.float)
        self.run_test(ArangeStartOutModel(), (x, y))

    def test_arange_no_type(self):
        class ArangeModel(nn.Module):
            def forward(self, end):
                return torch.arange(end), \
                    torch.arange(0, end)

        x = torch.tensor(6.2, dtype=torch.float)
        self.run_test(ArangeModel(), (x, ))
    def test_size(self):
        class SizeModel(nn.Module):
            def forward(self, input):
                return torch.arange(input.size(0)), torch.arange(input.size(-1)), torch.ones(input.shape)

        x = torch.randn(5, 3, 2)
        self.run_test(SizeModel(), (x, ))

    def test_size2(self):
        class SizeModel(nn.Module):
            def __init__(self, a, b):
                super().__init__()
                self.a = a
                self.b = b

            def forward(self, input):
                if self.a < self.b:
                    return torch.arange(input.size(0)), torch.arange(input.size(-1)), torch.ones(input.shape)

        x = torch.randn(5, 3, 2)
        self.run_test(SizeModel(10, 5), (x, ))
\ No newline at end of file
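Every test above funnels through the suite's run_test helper, which is defined earlier in test_convert_pytorch.py and not shown in this hunk. Conceptually it performs a round trip through the converter and code generator that this commit refactors. The sketch below shows that round trip; the import paths, the convert_to_graph / model_to_pytorch_script entry points, and the '_model' class name are assumptions based on the modules touched by this commit, not code copied from the test file:

    import torch

    # Assumed entry points into the converter and codegen packages of this commit.
    from nni.retiarii.converter import convert_to_graph
    from nni.retiarii.codegen import model_to_pytorch_script

    def round_trip_check(model: torch.nn.Module, inputs: tuple):
        # 1. Script the module so the converter can walk its TorchScript graph.
        script_module = torch.jit.script(model)
        # 2. Convert the scripted module into the retiarii graph IR.
        model_ir = convert_to_graph(script_module, model)
        # 3. Regenerate PyTorch code from the IR and materialize the class.
        model_code = model_to_pytorch_script(model_ir)
        namespace = {}
        exec(model_code, namespace)
        converted = namespace['_model']()               # '_model' is an assumed generated class name
        converted.load_state_dict(model.state_dict())   # assumes submodule names are preserved
        # 4. Original and regenerated models should agree on the same inputs
        #    (single-tensor outputs assumed for brevity).
        with torch.no_grad():
            assert torch.allclose(model(*inputs), converted(*inputs))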
test/ut/retiarii/test_highlevel_apis.py
View file @
58d5c2fa
...
@@ -167,6 +167,7 @@ class TestHighLevelAPI(unittest.TestCase):
        mutator = mutators[0].bind_sampler(EnuemrateSampler())
        model1 = mutator.apply(model)
        model2 = mutator.apply(model)
        self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3))
        self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(),
                         torch.Size([1, 3, 3, 3]))
        self.assertAlmostEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).abs().sum().item(), 0)
...