Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
58d5c2fa
Unverified
Commit
58d5c2fa
authored
Feb 22, 2021
by
QuanluZhang
Committed by
GitHub
Feb 22, 2021
Browse files
[retiarii] refactor of pytorch operators (#3365)
parent
59521d33
Changes
15
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
4020 additions
and
310 deletions
+4020
-310
nni/retiarii/codegen/pytorch.py
nni/retiarii/codegen/pytorch.py
+35
-6
nni/retiarii/converter/graph_gen.py
nni/retiarii/converter/graph_gen.py
+244
-120
nni/retiarii/converter/op_types.py
nni/retiarii/converter/op_types.py
+0
-26
nni/retiarii/nn/pytorch/mutator.py
nni/retiarii/nn/pytorch/mutator.py
+1
-1
nni/retiarii/operation.py
nni/retiarii/operation.py
+47
-52
nni/retiarii/operation_def/torch_op_def.py
nni/retiarii/operation_def/torch_op_def.py
+407
-29
nni/retiarii/utils.py
nni/retiarii/utils.py
+8
-0
pipelines/fast-test.yml
pipelines/fast-test.yml
+3
-1
test/ut/retiarii/inject_nn.py
test/ut/retiarii/inject_nn.py
+280
-0
test/ut/retiarii/test_convert.py
test/ut/retiarii/test_convert.py
+87
-75
test/ut/retiarii/test_convert_basic.py
test/ut/retiarii/test_convert_basic.py
+283
-0
test/ut/retiarii/test_convert_models.py
test/ut/retiarii/test_convert_models.py
+0
-0
test/ut/retiarii/test_convert_operators.py
test/ut/retiarii/test_convert_operators.py
+1390
-0
test/ut/retiarii/test_convert_pytorch.py
test/ut/retiarii/test_convert_pytorch.py
+1234
-0
test/ut/retiarii/test_highlevel_apis.py
test/ut/retiarii/test_highlevel_apis.py
+1
-0
No files found.
nni/retiarii/codegen/pytorch.py
View file @
58d5c2fa
import
logging
import
logging
from
typing
import
List
from
typing
import
List
,
Tuple
,
Any
from
..graph
import
IllegalGraphError
,
Edge
,
Graph
,
Node
,
Model
from
..graph
import
IllegalGraphError
,
Edge
,
Graph
,
Node
,
Model
...
@@ -32,9 +32,26 @@ def _sorted_incoming_edges(node: Node) -> List[Edge]:
...
@@ -32,9 +32,26 @@ def _sorted_incoming_edges(node: Node) -> List[Edge]:
raise
IllegalGraphError
(
node
.
graph
,
'Node {} has bad inputs'
.
format
(
node
.
name
))
raise
IllegalGraphError
(
node
.
graph
,
'Node {} has bad inputs'
.
format
(
node
.
name
))
def
_format_inputs
(
node
:
Node
)
->
List
[
str
]:
def
_format_inputs
(
node
:
Node
)
->
Tuple
[
List
[
str
],
List
[
Any
]]:
"""
Format the inputs of a given node
Parameters
----------
node : Node
a graph node, get and format its inputs
Returns
-------
list
the list of input names
list
the list of input values, if an input is simple type, record its value,
otherwise the value is None
"""
edges
=
_sorted_incoming_edges
(
node
)
edges
=
_sorted_incoming_edges
(
node
)
inputs
=
[]
inputs
=
[]
inputs_value
=
[]
for
edge
in
edges
:
for
edge
in
edges
:
if
edge
.
head
.
name
==
'_inputs'
:
if
edge
.
head
.
name
==
'_inputs'
:
assert
isinstance
(
edge
.
head_slot
,
int
)
assert
isinstance
(
edge
.
head_slot
,
int
)
...
@@ -44,14 +61,21 @@ def _format_inputs(node: Node) -> List[str]:
...
@@ -44,14 +61,21 @@ def _format_inputs(node: Node) -> List[str]:
else
:
else
:
# when input has no name, e.g., forward(*_inputs)
# when input has no name, e.g., forward(*_inputs)
inputs
.
append
(
'_inputs[{}]'
.
format
(
edge
.
head_slot
))
inputs
.
append
(
'_inputs[{}]'
.
format
(
edge
.
head_slot
))
inputs_value
.
append
(
None
)
else
:
else
:
if
edge
.
head_slot
is
None
:
if
edge
.
head_slot
is
None
:
# when the input comes from a single-output operator
# when the input comes from a single-output operator
inputs
.
append
(
'{}'
.
format
(
edge
.
head
.
name
))
inputs
.
append
(
'{}'
.
format
(
edge
.
head
.
name
))
if
edge
.
head
.
operation
.
type
in
(
'prim::Constant'
,
'prim::GetAttr'
)
and
\
'value'
in
edge
.
head
.
operation
.
parameters
:
inputs_value
.
append
(
edge
.
head
.
operation
.
parameters
[
'value'
])
else
:
inputs_value
.
append
(
None
)
else
:
else
:
# when the input comes from a multi-output operator: needs to know which one it comes from
# when the input comes from a multi-output operator: needs to know which one it comes from
inputs
.
append
(
'{}[{}]'
.
format
(
edge
.
head
.
name
,
edge
.
head_slot
))
inputs
.
append
(
'{}[{}]'
.
format
(
edge
.
head
.
name
,
edge
.
head_slot
))
return
inputs
inputs_value
.
append
(
None
)
return
inputs
,
inputs_value
def
_remove_prefix
(
names
,
graph_name
):
def
_remove_prefix
(
names
,
graph_name
):
...
@@ -80,6 +104,8 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
...
@@ -80,6 +104,8 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
node_codes
=
[]
node_codes
=
[]
for
node
in
nodes
:
for
node
in
nodes
:
if
node
.
operation
:
if
node
.
operation
:
if
node
.
operation
.
type
==
'shared'
:
continue
pkg_name
=
node
.
operation
.
get_import_pkg
()
pkg_name
=
node
.
operation
.
get_import_pkg
()
if
pkg_name
is
not
None
:
if
pkg_name
is
not
None
:
import_pkgs
.
add
(
pkg_name
)
import_pkgs
.
add
(
pkg_name
)
...
@@ -101,12 +127,15 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
...
@@ -101,12 +127,15 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
sorted_nodes
=
graph
.
topo_sort
()
sorted_nodes
=
graph
.
topo_sort
()
for
node
in
sorted_nodes
:
for
node
in
sorted_nodes
:
if
node
.
operation
:
if
node
.
operation
:
inputs
=
_format_inputs
(
node
)
inputs
,
inputs_value
=
_format_inputs
(
node
)
inputs
=
_remove_prefix
(
inputs
,
graph_name
)
inputs
=
_remove_prefix
(
inputs
,
graph_name
)
node_name
=
_remove_prefix
(
node
.
name
,
graph_name
)
node_name
=
_remove_prefix
(
node
.
name
,
graph_name
)
edge_codes
.
append
(
node
.
operation
.
to_forward_code
(
node_name
,
node_name
,
inputs
))
submodule_name
=
node_name
if
node
.
operation
.
type
==
'shared'
:
submodule_name
=
_remove_prefix
(
node
.
operation
.
parameters
[
'reference'
],
graph_name
)
edge_codes
.
append
(
node
.
operation
.
to_forward_code
(
submodule_name
,
node_name
,
inputs
,
inputs_value
))
output_names
=
_format_inputs
(
graph
.
output_node
)
output_names
,
_
=
_format_inputs
(
graph
.
output_node
)
output_names
=
_remove_prefix
(
output_names
,
graph_name
)
output_names
=
_remove_prefix
(
output_names
,
graph_name
)
if
not
output_names
:
if
not
output_names
:
raise
RuntimeError
(
'"forward" function should have return value(s): {}, {}, {}'
.
format
(
output_names
,
graph_name
,
graph
.
output_node
))
raise
RuntimeError
(
'"forward" function should have return value(s): {}, {}, {}'
.
format
(
output_names
,
graph_name
,
graph
.
output_node
))
...
...
nni/retiarii/converter/graph_gen.py
View file @
58d5c2fa
This diff is collapsed.
Click to expand it.
nni/retiarii/converter/op_types.py
View file @
58d5c2fa
...
@@ -9,34 +9,8 @@ class OpTypeName(str, Enum):
...
@@ -9,34 +9,8 @@ class OpTypeName(str, Enum):
"""
"""
Attr
=
'Attr'
Attr
=
'Attr'
Constant
=
'Constant'
Constant
=
'Constant'
ListConstruct
=
'ListConstruct'
TupleConstruct
=
'TupleConstruct'
LayerChoice
=
'LayerChoice'
LayerChoice
=
'LayerChoice'
InputChoice
=
'InputChoice'
InputChoice
=
'InputChoice'
ValueChoice
=
'ValueChoice'
ValueChoice
=
'ValueChoice'
Placeholder
=
'Placeholder'
Placeholder
=
'Placeholder'
MergedSlice
=
'MergedSlice'
MergedSlice
=
'MergedSlice'
# deal with aten op
BasicOpsPT
=
{
'aten::mean'
:
'Mean'
,
'aten::relu'
:
'Relu'
,
'aten::add'
:
'Add'
,
'aten::__getitem__'
:
'getitem'
,
'aten::append'
:
'Append'
,
'aten::len'
:
'Len'
,
'aten::slice'
:
'Slice'
,
'aten::cat'
:
'Cat'
,
'aten::size'
:
'Size'
,
'aten::view'
:
'View'
,
'aten::reshape'
:
'Reshape'
,
'aten::eq'
:
'Eq'
,
'aten::Bool'
:
'Bool'
,
'aten::empty'
:
'Empty'
,
'aten::zeros'
:
'Zeros'
,
'aten::chunk'
:
'Chunk'
,
'aten::add_'
:
'Add_'
# %out.3 : Tensor = aten::add_(%out.1, %connection.1, %4)
}
BasicOpsTF
=
{}
nni/retiarii/nn/pytorch/mutator.py
View file @
58d5c2fa
...
@@ -45,7 +45,7 @@ class ValueChoiceMutator(Mutator):
...
@@ -45,7 +45,7 @@ class ValueChoiceMutator(Mutator):
chosen
=
self
.
choice
(
self
.
candidates
)
chosen
=
self
.
choice
(
self
.
candidates
)
for
node
in
self
.
nodes
:
for
node
in
self
.
nodes
:
target
=
model
.
get_node_by_name
(
node
.
name
)
target
=
model
.
get_node_by_name
(
node
.
name
)
target
.
update_operation
(
'prim::Constant'
,
{
'value'
:
chosen
})
target
.
update_operation
(
'prim::Constant'
,
{
'type'
:
type
(
chosen
).
__name__
,
'value'
:
chosen
})
def
process_inline_mutation
(
model
:
Model
)
->
Optional
[
List
[
Mutator
]]:
def
process_inline_mutation
(
model
:
Model
)
->
Optional
[
List
[
Mutator
]]:
...
...
nni/retiarii/operation.py
View file @
58d5c2fa
...
@@ -83,6 +83,31 @@ class Operation:
...
@@ -83,6 +83,31 @@ class Operation:
class
PyTorchOperation
(
Operation
):
class
PyTorchOperation
(
Operation
):
@
classmethod
def
_find_subclass
(
cls
,
subclass_name
):
if
cls
.
to_class_name
(
subclass_name
)
is
not
None
:
subclass_name
=
'ModuleOperator'
if
cls
.
is_functional
(
subclass_name
):
subclass_name
=
'FunctionalOperator'
for
subclass
in
cls
.
__subclasses__
():
if
hasattr
(
subclass
,
'_ori_type_name'
)
and
\
subclass_name
in
subclass
.
_ori_type_name
:
return
subclass
return
cls
@
classmethod
def
to_class_name
(
cls
,
type_name
)
->
str
:
if
type_name
.
startswith
(
'__torch__.'
):
return
type_name
[
len
(
'__torch__.'
):]
elif
type_name
.
startswith
(
'__mutated__.'
):
return
type_name
[
len
(
'__mutated__.'
):]
else
:
return
None
@
classmethod
def
is_functional
(
cls
,
type_name
)
->
bool
:
return
type_name
.
startswith
(
'Function.'
)
def
_to_class_name
(
self
)
->
str
:
def
_to_class_name
(
self
)
->
str
:
if
self
.
type
.
startswith
(
'__torch__.'
):
if
self
.
type
.
startswith
(
'__torch__.'
):
return
self
.
type
[
len
(
'__torch__.'
):]
return
self
.
type
[
len
(
'__torch__.'
):]
...
@@ -106,59 +131,27 @@ class PyTorchOperation(Operation):
...
@@ -106,59 +131,27 @@ class PyTorchOperation(Operation):
return
f
'self.
{
field
}
=
{
self
.
_to_class_name
()
}
(
{
kw_params
}
)'
return
f
'self.
{
field
}
=
{
self
.
_to_class_name
()
}
(
{
kw_params
}
)'
return
None
return
None
def
to_forward_code
(
self
,
field
:
str
,
output
:
str
,
inputs
:
List
[
str
])
->
str
:
def
to_forward_code
(
self
,
field
:
str
,
output
:
str
,
inputs
:
List
[
str
],
inputs_value
:
List
[
Any
]
=
None
)
->
str
:
from
.converter.op_types
import
OpTypeName
"""
if
self
.
_to_class_name
()
is
not
None
:
Parameters
return
f
'
{
output
}
= self.
{
field
}
(
{
", "
.
join
(
inputs
)
}
)'
----------
elif
self
.
type
.
startswith
(
'Function.'
):
field : str
func_name
=
self
.
type
[
len
(
'Function.'
):]
the name of member submodule
return
f
'
{
output
}
= F.
{
func_name
}
(
{
", "
.
join
(
inputs
)
}
)'
output : str
elif
self
.
type
==
'prim::Constant'
:
the output name (lvalue) of this line of code
if
self
.
parameters
:
inputs : List[str]
value
=
self
.
parameters
[
'value'
]
variables used in this line of code
else
:
inputs_value : List[Any]
value
=
None
some variables are actually constant, their real values are recorded in ```inputs_value```.
return
f
'
{
output
}
=
{
value
}
'
if not constant, we simply put None at the corresponding index
elif
self
.
type
==
'prim::ListConstruct'
:
return
f
'
{
output
}
= [
{
", "
.
join
(
inputs
)
}
]'
Returns
elif
self
.
type
==
'prim::TupleConstruct'
:
-------
return
f
'
{
output
}
= (
{
", "
.
join
(
inputs
)
}
)'
str
elif
self
.
type
==
'prim::GetAttr'
:
generated code line
return
f
"
{
output
}
=
{
self
.
parameters
[
'input'
]
}
.
{
self
.
parameters
[
'name'
]
}
"
"""
elif
self
.
type
==
'aten::mean'
:
if
self
.
type
==
'aten::slice'
:
return
f
'
{
output
}
= torch.mean(
{
inputs
[
0
]
}
,
{
", "
.
join
(
inputs
[
1
:
-
1
])
}
, out=
{
inputs
[
-
1
]
}
)'
elif
self
.
type
==
'aten::__getitem__'
:
assert
len
(
inputs
)
==
2
return
f
'
{
output
}
=
{
inputs
[
0
]
}
[
{
inputs
[
1
]
}
]'
elif
self
.
type
==
'aten::append'
:
assert
len
(
inputs
)
==
2
return
f
'_,
{
output
}
=
{
inputs
[
0
]
}
.append(
{
inputs
[
1
]
}
),
{
inputs
[
0
]
}
'
elif
self
.
type
==
'aten::cat'
:
assert
len
(
inputs
)
==
2
return
f
'
{
output
}
= torch.cat(
{
inputs
[
0
]
}
, dim=
{
inputs
[
1
]
}
)'
elif
self
.
type
==
'aten::add'
:
return
f
'
{
output
}
= '
+
' + '
.
join
(
inputs
)
elif
self
.
type
==
OpTypeName
.
MergedSlice
:
assert
(
len
(
inputs
)
-
1
)
%
4
==
0
slices
=
[]
dim
=
int
((
len
(
inputs
)
-
1
)
/
4
)
for
i
in
range
(
dim
):
slices
.
append
(
f
'
{
inputs
[
i
*
4
+
2
]
}
:
{
inputs
[
i
*
4
+
3
]
}
:
{
inputs
[
i
*
4
+
4
]
}
'
)
slice_str
=
','
.
join
(
slices
)
return
f
'
{
output
}
=
{
inputs
[
0
]
}
[
{
slice_str
}
]'
elif
self
.
type
==
'aten::size'
:
assert
len
(
inputs
)
==
2
return
f
'
{
output
}
=
{
inputs
[
0
]
}
.size(
{
inputs
[
1
]
}
)'
elif
self
.
type
==
'aten::view'
:
assert
len
(
inputs
)
==
2
return
f
'
{
output
}
=
{
inputs
[
0
]
}
.view(
{
inputs
[
1
]
}
)'
elif
self
.
type
==
'aten::reshape'
:
assert
len
(
inputs
)
==
2
return
f
'
{
output
}
=
{
inputs
[
0
]
}
.reshape(
{
inputs
[
1
]
}
)'
elif
self
.
type
==
'aten::slice'
:
raise
RuntimeError
(
'not supposed to have aten::slice operation'
)
raise
RuntimeError
(
'not supposed to have aten::slice operation'
)
elif
self
.
type
==
'aten::Bool'
:
return
f
'
{
output
}
= bool(
{
inputs
[
0
]
}
)'
else
:
else
:
raise
RuntimeError
(
f
'unsupported operation type:
{
self
.
type
}
?
{
self
.
_to_class_name
()
}
'
)
raise
RuntimeError
(
f
'unsupported operation type:
{
self
.
type
}
?
{
self
.
_to_class_name
()
}
'
)
...
@@ -212,6 +205,8 @@ class Cell(PyTorchOperation):
...
@@ -212,6 +205,8 @@ class Cell(PyTorchOperation):
# TODO: ugly, think about how to refactor this part
# TODO: ugly, think about how to refactor this part
return
_convert_name
(
self
.
cell_name
)
return
_convert_name
(
self
.
cell_name
)
def
to_forward_code
(
self
,
field
:
str
,
output
:
str
,
inputs
:
List
[
str
],
inputs_value
:
List
[
Any
]
=
None
)
->
str
:
return
f
'
{
output
}
= self.
{
field
}
(
{
", "
.
join
(
inputs
)
}
)'
class
_IOPseudoOperation
(
Operation
):
class
_IOPseudoOperation
(
Operation
):
"""
"""
...
...
nni/retiarii/operation_def/torch_op_def.py
View file @
58d5c2fa
This diff is collapsed.
Click to expand it.
nni/retiarii/utils.py
View file @
58d5c2fa
...
@@ -162,6 +162,14 @@ def _get_module_name(cls):
...
@@ -162,6 +162,14 @@ def _get_module_name(cls):
f
'please launch the experiment under the directory where "
{
main_file_path
.
name
}
" is located.'
)
f
'please launch the experiment under the directory where "
{
main_file_path
.
name
}
" is located.'
)
module_name
=
main_file_path
.
stem
module_name
=
main_file_path
.
stem
break
break
# NOTE: this is hacky. TorchScript needs to retrieve LSTM's source code,
# so to make that source code findable we assign the original LSTM's __module__ to
# the wrapped LSTM's __module__.
# TODO: find out all the modules that have the same requirement as LSTM
if
f
'
{
cls
.
__module__
}
.
{
cls
.
__name__
}
'
==
'torch.nn.modules.rnn.LSTM'
:
module_name
=
cls
.
__module__
return
module_name
return
module_name
...
...
pipelines/fast-test.yml
View file @
58d5c2fa
...
@@ -250,7 +250,9 @@ stages:
...
@@ -250,7 +250,9 @@ stages:
-
script
:
|
-
script
:
|
cd test
cd test
python -m pytest ut
python -m pytest ut --ignore=ut/retiarii/test_convert_basic.py \
--ignore=ut/retiarii/test_convert_operators.py \
--ignore=ut/retiarii/test_convert_pytorch.py
displayName
:
Python unit test
displayName
:
Python unit test
-
script
:
|
-
script
:
|
...
...
test/ut/retiarii/inject_nn.py
0 → 100644
View file @
58d5c2fa
import
inspect
import
logging
import
torch
import
torch.nn
as
nn
from
nni.retiarii.utils
import
add_record
,
del_record
,
version_larger_equal
_logger
=
logging
.
getLogger
(
__name__
)
def wrap_module(original_class):
    """
    Patch ``original_class`` in place so that each instantiation records its
    constructor arguments.

    ``__init__`` is replaced by a wrapper that merges the keyword arguments
    with the positional arguments (mapped to parameter names taken from
    ``inspect.signature(original_class)``), hands the result to ``add_record``
    keyed by ``id(self)``, and then calls the previous ``__init__``.
    ``__del__`` is replaced by a wrapper that calls ``del_record`` and then the
    previous ``__del__`` when the class had one. The replaced callables are
    stored on the class as ``bak_init_for_inject`` and ``bak_del_for_inject``
    so that they can be restored later.

    Calling this function again on an already wrapped class is a no-op.

    Parameters
    ----------
    original_class : type
        the class to patch

    Returns
    -------
    type
        ``original_class`` itself (patched in place)
    """
    # Wrapping twice would overwrite the backups with our own wrappers, after
    # which the real __init__/__del__ could never be restored. Look only at the
    # class' own namespace: a subclass of a wrapped class inherits the backup
    # attributes but still has to be wrapped itself.
    if 'bak_init_for_inject' in vars(original_class):
        return original_class

    orig_init = original_class.__init__
    argname_list = list(inspect.signature(original_class).parameters.keys())
    # Make copy of original __init__, so we can call it without recursion
    original_class.bak_init_for_inject = orig_init

    if hasattr(original_class, '__del__'):
        orig_del = original_class.__del__
        original_class.bak_del_for_inject = orig_del
    else:
        orig_del = None
        original_class.bak_del_for_inject = None

    def __init__(self, *args, **kws):
        full_args = {}
        full_args.update(kws)
        # positional arguments are recorded under their parameter names
        for i, arg in enumerate(args):
            full_args[argname_list[i]] = arg
        add_record(id(self), full_args)

        orig_init(self, *args, **kws)  # Call the original __init__

    def __del__(self):
        del_record(id(self))
        if orig_del is not None:
            orig_del(self)

    original_class.__init__ = __init__  # Set the class' __init__ to the new one
    original_class.__del__ = __del__
    return original_class
def unwrap_module(wrapped_class):
    """
    Restore the ``__init__`` / ``__del__`` that were saved on ``wrapped_class``
    as ``bak_init_for_inject`` / ``bak_del_for_inject`` and remove those
    backup attributes.

    Only attributes stored on ``wrapped_class`` itself are considered. A class
    that merely inherits the backups from a wrapped base class is left
    untouched (previously this case copied the base's ``__init__`` onto the
    subclass and then failed in ``delattr``).

    When the saved ``__del__`` is ``None`` (the class had no ``__del__`` when
    the backup was taken), the currently installed ``__del__`` is left in place.

    Parameters
    ----------
    wrapped_class : type
        the class to restore

    Returns
    -------
    None
    """
    # vars() exposes the class' own namespace only, unlike hasattr() which
    # also finds attributes inherited from base classes.
    own_attrs = vars(wrapped_class)
    if 'bak_init_for_inject' in own_attrs:
        wrapped_class.__init__ = wrapped_class.bak_init_for_inject
        delattr(wrapped_class, 'bak_init_for_inject')
    if 'bak_del_for_inject' in own_attrs:
        if wrapped_class.bak_del_for_inject is not None:
            wrapped_class.__del__ = wrapped_class.bak_del_for_inject
        delattr(wrapped_class, 'bak_del_for_inject')
    return None
def remove_inject_pytorch_nn():
    """
    Call ``unwrap_module`` on each ``torch.nn`` class listed below, in order.

    ``Hardswish`` is processed only when ``torch.__version__`` is at least
    1.6.0, and ``SiLU``, ``Unflatten`` and ``TripletMarginWithDistanceLoss``
    only when it is at least 1.7.0.
    """
    unconditional = (
        'Identity', 'Linear',
        'Conv1d', 'Conv2d', 'Conv3d',
        'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d',
        'Threshold', 'ReLU', 'Hardtanh', 'ReLU6', 'Sigmoid', 'Tanh',
        'Softmax', 'Softmax2d', 'LogSoftmax',
        'ELU', 'SELU', 'CELU', 'GLU', 'GELU',
        'Hardshrink', 'LeakyReLU', 'LogSigmoid', 'Softplus', 'Softshrink',
        'MultiheadAttention', 'PReLU', 'Softsign', 'Softmin', 'Tanhshrink', 'RReLU',
        'AvgPool1d', 'AvgPool2d', 'AvgPool3d',
        'MaxPool1d', 'MaxPool2d', 'MaxPool3d',
        'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d',
        'FractionalMaxPool2d', 'FractionalMaxPool3d',
        'LPPool1d', 'LPPool2d', 'LocalResponseNorm',
        'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d',
        'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d',
        'LayerNorm', 'GroupNorm', 'SyncBatchNorm',
        'Dropout', 'Dropout2d', 'Dropout3d', 'AlphaDropout', 'FeatureAlphaDropout',
        'ReflectionPad1d', 'ReflectionPad2d',
        'ReplicationPad2d', 'ReplicationPad1d', 'ReplicationPad3d',
        'CrossMapLRN2d', 'Embedding', 'EmbeddingBag',
        'RNNBase', 'RNN', 'LSTM', 'GRU',
        'RNNCellBase', 'RNNCell', 'LSTMCell', 'GRUCell',
        'PixelShuffle', 'Upsample', 'UpsamplingNearest2d', 'UpsamplingBilinear2d',
        'PairwiseDistance',
        'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d',
        'AdaptiveAvgPool1d', 'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d',
        'TripletMarginLoss', 'ZeroPad2d',
        'ConstantPad1d', 'ConstantPad2d', 'ConstantPad3d',
        'Bilinear', 'CosineSimilarity', 'Unfold', 'Fold',
        'AdaptiveLogSoftmaxWithLoss',
        'TransformerEncoder', 'TransformerDecoder',
        'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer',
        'Flatten', 'Hardsigmoid',
    )
    for cls_name in unconditional:
        unwrap_module(getattr(nn, cls_name))

    if version_larger_equal(torch.__version__, '1.6.0'):
        unwrap_module(nn.Hardswish)

    # NOTE(review): indentation is not visible in this view; the three classes
    # below are assumed to all sit under the 1.7.0 check - confirm.
    if version_larger_equal(torch.__version__, '1.7.0'):
        for cls_name in ('SiLU', 'Unflatten', 'TripletMarginWithDistanceLoss'):
            unwrap_module(getattr(nn, cls_name))
def inject_pytorch_nn():
    """
    Call ``wrap_module`` on each ``torch.nn`` class listed below, in order,
    patching the classes in place.

    ``Hardswish`` is processed only when ``torch.__version__`` is at least
    1.6.0, and ``SiLU``, ``Unflatten`` and ``TripletMarginWithDistanceLoss``
    only when it is at least 1.7.0.
    """
    unconditional = (
        'Identity', 'Linear',
        'Conv1d', 'Conv2d', 'Conv3d',
        'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d',
        'Threshold', 'ReLU', 'Hardtanh', 'ReLU6', 'Sigmoid', 'Tanh',
        'Softmax', 'Softmax2d', 'LogSoftmax',
        'ELU', 'SELU', 'CELU', 'GLU', 'GELU',
        'Hardshrink', 'LeakyReLU', 'LogSigmoid', 'Softplus', 'Softshrink',
        'MultiheadAttention', 'PReLU', 'Softsign', 'Softmin', 'Tanhshrink', 'RReLU',
        'AvgPool1d', 'AvgPool2d', 'AvgPool3d',
        'MaxPool1d', 'MaxPool2d', 'MaxPool3d',
        'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d',
        'FractionalMaxPool2d', 'FractionalMaxPool3d',
        'LPPool1d', 'LPPool2d', 'LocalResponseNorm',
        'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d',
        'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d',
        'LayerNorm', 'GroupNorm', 'SyncBatchNorm',
        'Dropout', 'Dropout2d', 'Dropout3d', 'AlphaDropout', 'FeatureAlphaDropout',
        'ReflectionPad1d', 'ReflectionPad2d',
        'ReplicationPad2d', 'ReplicationPad1d', 'ReplicationPad3d',
        'CrossMapLRN2d', 'Embedding', 'EmbeddingBag',
        'RNNBase', 'RNN', 'LSTM', 'GRU',
        'RNNCellBase', 'RNNCell', 'LSTMCell', 'GRUCell',
        'PixelShuffle', 'Upsample', 'UpsamplingNearest2d', 'UpsamplingBilinear2d',
        'PairwiseDistance',
        'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d',
        'AdaptiveAvgPool1d', 'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d',
        'TripletMarginLoss', 'ZeroPad2d',
        'ConstantPad1d', 'ConstantPad2d', 'ConstantPad3d',
        'Bilinear', 'CosineSimilarity', 'Unfold', 'Fold',
        'AdaptiveLogSoftmaxWithLoss',
        'TransformerEncoder', 'TransformerDecoder',
        'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer',
        'Flatten', 'Hardsigmoid',
    )
    for cls_name in unconditional:
        wrap_module(getattr(nn, cls_name))

    if version_larger_equal(torch.__version__, '1.6.0'):
        wrap_module(nn.Hardswish)

    # NOTE(review): indentation is not visible in this view; the three classes
    # below are assumed to all sit under the 1.7.0 check - confirm.
    if version_larger_equal(torch.__version__, '1.7.0'):
        for cls_name in ('SiLU', 'Unflatten', 'TripletMarginWithDistanceLoss'):
            wrap_module(getattr(nn, cls_name))
test/ut/retiarii/test_convert.py
View file @
58d5c2fa
...
@@ -35,16 +35,29 @@ class MnistNet(nn.Module):
...
@@ -35,16 +35,29 @@ class MnistNet(nn.Module):
x
=
self
.
fc2
(
x
)
x
=
self
.
fc2
(
x
)
return
F
.
log_softmax
(
x
,
dim
=
1
)
return
F
.
log_softmax
(
x
,
dim
=
1
)
# NOTE: blackbox module cannot be placed within class or function
@
blackbox_module
class
Linear
(
nn
.
Module
):
def
__init__
(
self
,
d_embed
,
d_proj
):
super
().
__init__
()
self
.
linear
=
nn
.
Linear
(
d_embed
,
d_proj
)
def
forward
(
self
,
input
):
if
len
(
input
.
size
())
<=
2
:
return
self
.
linear
(
input
)
size
=
input
.
size
()[:
2
]
out
=
self
.
linear
(
input
.
view
(
size
[
0
]
*
size
[
1
],
-
1
))
return
out
.
view
(
size
[
0
],
size
[
1
],
-
1
)
class
TestConvert
(
unittest
.
TestCase
):
class
TestConvert
(
unittest
.
TestCase
):
@
staticmethod
@
staticmethod
def
_match_state_dict
(
current_values
,
expected_format
):
def
_match_state_dict
(
current_values
,
expected_format
):
result
=
{}
result
=
{}
for
k
,
v
in
expected_format
.
items
():
for
k
,
v
in
expected_format
.
items
():
for
cv
in
current_values
:
for
idx
,
cv
in
enumerate
(
current_values
)
:
if
cv
.
shape
==
v
.
shape
:
if
cv
.
shape
==
v
.
shape
:
result
[
k
]
=
cv
result
[
k
]
=
cv
current_values
.
remove
(
cv
)
current_values
.
pop
(
idx
)
break
break
return
result
return
result
...
@@ -53,6 +66,9 @@ class TestConvert(unittest.TestCase):
...
@@ -53,6 +66,9 @@ class TestConvert(unittest.TestCase):
model_ir
=
convert_to_graph
(
script_module
,
model
)
model_ir
=
convert_to_graph
(
script_module
,
model
)
model_code
=
model_to_pytorch_script
(
model_ir
)
model_code
=
model_to_pytorch_script
(
model_ir
)
from
.inject_nn
import
remove_inject_pytorch_nn
remove_inject_pytorch_nn
()
exec_vars
=
{}
exec_vars
=
{}
exec
(
model_code
+
'
\n\n
converted_model = _model()'
,
exec_vars
)
exec
(
model_code
+
'
\n\n
converted_model = _model()'
,
exec_vars
)
converted_model
=
exec_vars
[
'converted_model'
]
converted_model
=
exec_vars
[
'converted_model'
]
...
@@ -134,18 +150,17 @@ class TestConvert(unittest.TestCase):
...
@@ -134,18 +150,17 @@ class TestConvert(unittest.TestCase):
model
=
DCGANGenerator
(
nz
,
ngf
,
nc
)
model
=
DCGANGenerator
(
nz
,
ngf
,
nc
)
self
.
checkExportImport
(
model
,
input
)
self
.
checkExportImport
(
model
,
input
)
@
unittest
.
skip
(
'this test has a if condition that needs to be handle'
)
# FIXME
def
test_neural_style
(
self
):
def
test_neural_style
(
self
):
class
TransformerNet
(
torch
.
nn
.
Module
):
class
TransformerNet
(
nn
.
Module
):
def
__init__
(
self
):
def
__init__
(
self
):
super
(
TransformerNet
,
self
).
__init__
()
super
(
TransformerNet
,
self
).
__init__
()
# Initial convolution layers
# Initial convolution layers
self
.
conv1
=
ConvLayer
(
3
,
32
,
kernel_size
=
9
,
stride
=
1
)
self
.
conv1
=
ConvLayer
(
3
,
32
,
kernel_size
=
9
,
stride
=
1
)
self
.
in1
=
torch
.
nn
.
InstanceNorm2d
(
32
,
affine
=
True
)
self
.
in1
=
nn
.
InstanceNorm2d
(
32
,
affine
=
True
)
self
.
conv2
=
ConvLayer
(
32
,
64
,
kernel_size
=
3
,
stride
=
2
)
self
.
conv2
=
ConvLayer
(
32
,
64
,
kernel_size
=
3
,
stride
=
2
)
self
.
in2
=
torch
.
nn
.
InstanceNorm2d
(
64
,
affine
=
True
)
self
.
in2
=
nn
.
InstanceNorm2d
(
64
,
affine
=
True
)
self
.
conv3
=
ConvLayer
(
64
,
128
,
kernel_size
=
3
,
stride
=
2
)
self
.
conv3
=
ConvLayer
(
64
,
128
,
kernel_size
=
3
,
stride
=
2
)
self
.
in3
=
torch
.
nn
.
InstanceNorm2d
(
128
,
affine
=
True
)
self
.
in3
=
nn
.
InstanceNorm2d
(
128
,
affine
=
True
)
# Residual layers
# Residual layers
self
.
res1
=
ResidualBlock
(
128
)
self
.
res1
=
ResidualBlock
(
128
)
self
.
res2
=
ResidualBlock
(
128
)
self
.
res2
=
ResidualBlock
(
128
)
...
@@ -154,12 +169,12 @@ class TestConvert(unittest.TestCase):
...
@@ -154,12 +169,12 @@ class TestConvert(unittest.TestCase):
self
.
res5
=
ResidualBlock
(
128
)
self
.
res5
=
ResidualBlock
(
128
)
# Upsampling Layers
# Upsampling Layers
self
.
deconv1
=
UpsampleConvLayer
(
128
,
64
,
kernel_size
=
3
,
stride
=
1
,
upsample
=
2
)
self
.
deconv1
=
UpsampleConvLayer
(
128
,
64
,
kernel_size
=
3
,
stride
=
1
,
upsample
=
2
)
self
.
in4
=
torch
.
nn
.
InstanceNorm2d
(
64
,
affine
=
True
)
self
.
in4
=
nn
.
InstanceNorm2d
(
64
,
affine
=
True
)
self
.
deconv2
=
UpsampleConvLayer
(
64
,
32
,
kernel_size
=
3
,
stride
=
1
,
upsample
=
2
)
self
.
deconv2
=
UpsampleConvLayer
(
64
,
32
,
kernel_size
=
3
,
stride
=
1
,
upsample
=
2
)
self
.
in5
=
torch
.
nn
.
InstanceNorm2d
(
32
,
affine
=
True
)
self
.
in5
=
nn
.
InstanceNorm2d
(
32
,
affine
=
True
)
self
.
deconv3
=
ConvLayer
(
32
,
3
,
kernel_size
=
9
,
stride
=
1
)
self
.
deconv3
=
ConvLayer
(
32
,
3
,
kernel_size
=
9
,
stride
=
1
)
# Non-linearities
# Non-linearities
self
.
relu
=
torch
.
nn
.
ReLU
()
self
.
relu
=
nn
.
ReLU
()
def
forward
(
self
,
X
):
def
forward
(
self
,
X
):
y
=
self
.
relu
(
self
.
in1
(
self
.
conv1
(
X
)))
y
=
self
.
relu
(
self
.
in1
(
self
.
conv1
(
X
)))
...
@@ -175,19 +190,19 @@ class TestConvert(unittest.TestCase):
...
@@ -175,19 +190,19 @@ class TestConvert(unittest.TestCase):
y
=
self
.
deconv3
(
y
)
y
=
self
.
deconv3
(
y
)
return
y
return
y
class
ConvLayer
(
torch
.
nn
.
Module
):
class
ConvLayer
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
,
stride
):
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
,
stride
):
super
(
ConvLayer
,
self
).
__init__
()
super
(
ConvLayer
,
self
).
__init__
()
reflection_padding
=
kernel_size
//
2
reflection_padding
=
kernel_size
//
2
self
.
reflection_pad
=
torch
.
nn
.
ReflectionPad2d
(
reflection_padding
)
self
.
reflection_pad
=
nn
.
ReflectionPad2d
(
reflection_padding
)
self
.
conv2d
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
,
stride
)
self
.
conv2d
=
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
,
stride
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
out
=
self
.
reflection_pad
(
x
)
out
=
self
.
reflection_pad
(
x
)
out
=
self
.
conv2d
(
out
)
out
=
self
.
conv2d
(
out
)
return
out
return
out
class
ResidualBlock
(
torch
.
nn
.
Module
):
class
ResidualBlock
(
nn
.
Module
):
"""ResidualBlock
"""ResidualBlock
introduced in: https://arxiv.org/abs/1512.03385
introduced in: https://arxiv.org/abs/1512.03385
recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html
recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html
...
@@ -196,10 +211,10 @@ class TestConvert(unittest.TestCase):
...
@@ -196,10 +211,10 @@ class TestConvert(unittest.TestCase):
def
__init__
(
self
,
channels
):
def
__init__
(
self
,
channels
):
super
(
ResidualBlock
,
self
).
__init__
()
super
(
ResidualBlock
,
self
).
__init__
()
self
.
conv1
=
ConvLayer
(
channels
,
channels
,
kernel_size
=
3
,
stride
=
1
)
self
.
conv1
=
ConvLayer
(
channels
,
channels
,
kernel_size
=
3
,
stride
=
1
)
self
.
in1
=
torch
.
nn
.
InstanceNorm2d
(
channels
,
affine
=
True
)
self
.
in1
=
nn
.
InstanceNorm2d
(
channels
,
affine
=
True
)
self
.
conv2
=
ConvLayer
(
channels
,
channels
,
kernel_size
=
3
,
stride
=
1
)
self
.
conv2
=
ConvLayer
(
channels
,
channels
,
kernel_size
=
3
,
stride
=
1
)
self
.
in2
=
torch
.
nn
.
InstanceNorm2d
(
channels
,
affine
=
True
)
self
.
in2
=
nn
.
InstanceNorm2d
(
channels
,
affine
=
True
)
self
.
relu
=
torch
.
nn
.
ReLU
()
self
.
relu
=
nn
.
ReLU
()
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
residual
=
x
residual
=
x
...
@@ -208,7 +223,7 @@ class TestConvert(unittest.TestCase):
...
@@ -208,7 +223,7 @@ class TestConvert(unittest.TestCase):
out
=
out
+
residual
out
=
out
+
residual
return
out
return
out
class
UpsampleConvLayer
(
torch
.
nn
.
Module
):
class
UpsampleConvLayer
(
nn
.
Module
):
"""UpsampleConvLayer
"""UpsampleConvLayer
Upsamples the input and then does a convolution. This method gives better results
Upsamples the input and then does a convolution. This method gives better results
compared to ConvTranspose2d.
compared to ConvTranspose2d.
...
@@ -219,10 +234,10 @@ class TestConvert(unittest.TestCase):
...
@@ -219,10 +234,10 @@ class TestConvert(unittest.TestCase):
super
(
UpsampleConvLayer
,
self
).
__init__
()
super
(
UpsampleConvLayer
,
self
).
__init__
()
self
.
upsample
=
upsample
self
.
upsample
=
upsample
if
upsample
:
if
upsample
:
self
.
upsample_layer
=
torch
.
nn
.
Upsample
(
mode
=
'nearest'
,
scale_factor
=
upsample
)
self
.
upsample_layer
=
nn
.
Upsample
(
mode
=
'nearest'
,
scale_factor
=
upsample
)
reflection_padding
=
kernel_size
//
2
reflection_padding
=
kernel_size
//
2
self
.
reflection_pad
=
torch
.
nn
.
ReflectionPad2d
(
reflection_padding
)
self
.
reflection_pad
=
nn
.
ReflectionPad2d
(
reflection_padding
)
self
.
conv2d
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
,
stride
)
self
.
conv2d
=
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
,
stride
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
x_in
=
x
x_in
=
x
...
@@ -254,50 +269,40 @@ class TestConvert(unittest.TestCase):
...
@@ -254,50 +269,40 @@ class TestConvert(unittest.TestCase):
self
.
checkExportImport
(
Policy
(),
(
torch
.
rand
(
1
,
4
),))
self
.
checkExportImport
(
Policy
(),
(
torch
.
rand
(
1
,
4
),))
@
unittest
.
skip
(
'Replaced init error.'
)
# FIXME
def
test_snli
(
self
):
def
test_snli
(
self
):
class
Bottle
(
nn
.
Module
):
def
forward
(
self
,
input
):
if
len
(
input
.
size
())
<=
2
:
return
super
(
Bottle
,
self
).
forward
(
input
)
size
=
input
.
size
()[:
2
]
out
=
super
(
Bottle
,
self
).
forward
(
input
.
view
(
size
[
0
]
*
size
[
1
],
-
1
))
return
out
.
view
(
size
[
0
],
size
[
1
],
-
1
)
class
Linear
(
Bottle
,
nn
.
Linear
):
pass
class
Encoder
(
nn
.
Module
):
class
Encoder
(
nn
.
Module
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
super
(
Encoder
,
self
).
__init__
()
super
(
Encoder
,
self
).
__init__
()
self
.
config
=
config
#self.config = config
input_size
=
config
.
d_proj
if
config
.
projection
else
config
.
d_embed
input_size
=
config
[
"d_proj"
]
if
config
[
"projection"
]
else
config
[
"d_embed"
]
dropout
=
0
if
config
.
n_layers
==
1
else
config
.
dp_ratio
dropout
=
0
if
config
[
"n_layers"
]
==
1
else
config
[
"dp_ratio"
]
self
.
rnn
=
nn
.
LSTM
(
input_size
=
input_size
,
hidden_size
=
config
.
d_hidden
,
self
.
rnn
=
nn
.
LSTM
(
input_size
=
input_size
,
hidden_size
=
config
[
"d_hidden"
],
num_layers
=
config
.
n_layers
,
dropout
=
dropout
,
num_layers
=
config
[
"n_layers"
],
dropout
=
dropout
,
bidirectional
=
config
.
birnn
)
bidirectional
=
config
[
"birnn"
])
self
.
n_cells
=
config
[
"n_cells"
]
self
.
d_hidden
=
config
[
"d_hidden"
]
self
.
birnn
=
config
[
"birnn"
]
def
forward
(
self
,
inputs
):
def
forward
(
self
,
inputs
):
batch_size
=
inputs
.
size
()[
1
]
batch_size
=
inputs
.
size
()[
1
]
state_shape
=
self
.
config
.
n_cells
,
batch_size
,
self
.
config
.
d_hidden
state_shape
=
self
.
n_cells
,
batch_size
,
self
.
d_hidden
h0
=
c0
=
inputs
.
new_zeros
(
state_shape
)
h0
=
c0
=
inputs
.
new_zeros
(
state_shape
)
outputs
,
(
ht
,
ct
)
=
self
.
rnn
(
inputs
,
(
h0
,
c0
))
outputs
,
(
ht
,
ct
)
=
self
.
rnn
(
inputs
,
(
h0
,
c0
))
return
ht
[
-
1
]
if
not
self
.
config
.
birnn
else
ht
[
-
2
:].
transpose
(
0
,
1
).
contiguous
().
view
(
batch_size
,
-
1
)
return
ht
[
-
1
]
if
not
self
.
birnn
else
ht
[
-
2
:].
transpose
(
0
,
1
).
contiguous
().
view
(
batch_size
,
-
1
)
class
SNLIClassifier
(
nn
.
Module
):
class
SNLIClassifier
(
nn
.
Module
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
super
(
SNLIClassifier
,
self
).
__init__
()
super
(
SNLIClassifier
,
self
).
__init__
()
self
.
config
=
config
self
.
embed
=
nn
.
Embedding
(
config
[
"n_embed"
],
config
[
"d_embed"
])
self
.
embed
=
nn
.
Embedding
(
config
.
n_embed
,
config
.
d_embed
)
self
.
projection
=
Linear
(
config
[
"d_embed"
],
config
[
"d_proj"
])
self
.
projection
=
Linear
(
config
.
d_embed
,
config
.
d_proj
)
self
.
encoder
=
Encoder
(
config
)
self
.
encoder
=
Encoder
(
config
)
self
.
dropout
=
nn
.
Dropout
(
p
=
config
.
dp_ratio
)
self
.
dropout
=
nn
.
Dropout
(
p
=
config
[
"
dp_ratio
"
]
)
self
.
relu
=
nn
.
ReLU
()
self
.
relu
=
nn
.
ReLU
()
seq_in_size
=
2
*
config
.
d_hidden
seq_in_size
=
2
*
config
[
"
d_hidden
"
]
if
self
.
config
.
birnn
:
if
config
[
"
birnn
"
]
:
seq_in_size
*=
2
seq_in_size
*=
2
lin_config
=
[
seq_in_size
]
*
2
lin_config
=
[
seq_in_size
]
*
2
self
.
out
=
nn
.
Sequential
(
self
.
out
=
nn
.
Sequential
(
...
@@ -310,15 +315,17 @@ class TestConvert(unittest.TestCase):
...
@@ -310,15 +315,17 @@ class TestConvert(unittest.TestCase):
Linear
(
*
lin_config
),
Linear
(
*
lin_config
),
self
.
relu
,
self
.
relu
,
self
.
dropout
,
self
.
dropout
,
Linear
(
seq_in_size
,
config
.
d_out
))
Linear
(
seq_in_size
,
config
[
"d_out"
]))
self
.
fix_emb
=
config
[
"fix_emb"
]
self
.
project
=
config
[
"projection"
]
def
forward
(
self
,
premise
,
hypothesis
):
def
forward
(
self
,
premise
,
hypothesis
):
prem_embed
=
self
.
embed
(
premise
)
prem_embed
=
self
.
embed
(
premise
)
hypo_embed
=
self
.
embed
(
hypothesis
)
hypo_embed
=
self
.
embed
(
hypothesis
)
if
self
.
config
.
fix_emb
:
if
self
.
fix_emb
:
prem_embed
=
prem_embed
.
detach
()
prem_embed
=
prem_embed
.
detach
()
hypo_embed
=
hypo_embed
.
detach
()
hypo_embed
=
hypo_embed
.
detach
()
if
self
.
config
.
project
ion
:
if
self
.
project
:
prem_embed
=
self
.
relu
(
self
.
projection
(
prem_embed
))
prem_embed
=
self
.
relu
(
self
.
projection
(
prem_embed
))
hypo_embed
=
self
.
relu
(
self
.
projection
(
hypo_embed
))
hypo_embed
=
self
.
relu
(
self
.
projection
(
hypo_embed
))
premise
=
self
.
encoder
(
prem_embed
)
premise
=
self
.
encoder
(
prem_embed
)
...
@@ -326,23 +333,24 @@ class TestConvert(unittest.TestCase):
...
@@ -326,23 +333,24 @@ class TestConvert(unittest.TestCase):
scores
=
self
.
out
(
torch
.
cat
([
premise
,
hypothesis
],
1
))
scores
=
self
.
out
(
torch
.
cat
([
premise
,
hypothesis
],
1
))
return
scores
return
scores
class
Config
:
Config
=
{
n_embed
=
100
"n_embed"
:
100
,
d_embed
=
100
"d_embed"
:
100
,
d_proj
=
300
"d_proj"
:
300
,
dp_ratio
=
0.0
# For deterministic testing TODO: change by fixing seed in checkTrace?
"dp_ratio"
:
0.0
,
# For deterministic testing TOD": change by fixing seed in checkTrace?,
d_hidden
=
30
"d_hidden"
:
30
,
birnn
=
True
"birnn"
:
True
,
d_out
=
300
"d_out"
:
300
,
fix_emb
=
True
"fix_emb"
:
True
,
projection
=
True
"projection"
:
True
,
n_layers
=
2
"n_layers"
:
2
,
n_cells
=
4
# 2 * n_layers because birnn = True
"n_cells"
:
4
# 2 * n_layers because birnn = True,
}
premise
=
torch
.
LongTensor
(
48
,
64
).
random_
(
0
,
100
)
premise
=
torch
.
LongTensor
(
48
,
64
).
random_
(
0
,
100
)
hypothesis
=
torch
.
LongTensor
(
24
,
64
).
random_
(
0
,
100
)
hypothesis
=
torch
.
LongTensor
(
24
,
64
).
random_
(
0
,
100
)
self
.
checkExportImport
(
SNLIClassifier
(
Config
()
),
(
premise
,
hypothesis
))
self
.
checkExportImport
(
SNLIClassifier
(
Config
),
(
premise
,
hypothesis
))
def
test_super_resolution
(
self
):
def
test_super_resolution
(
self
):
class
Net
(
nn
.
Module
):
class
Net
(
nn
.
Module
):
...
@@ -367,16 +375,16 @@ class TestConvert(unittest.TestCase):
...
@@ -367,16 +375,16 @@ class TestConvert(unittest.TestCase):
net
=
Net
(
upscale_factor
=
4
)
net
=
Net
(
upscale_factor
=
4
)
self
.
checkExportImport
(
net
,
(
torch
.
rand
(
5
,
1
,
32
,
32
),))
self
.
checkExportImport
(
net
,
(
torch
.
rand
(
5
,
1
,
32
,
32
),))
@
unittest
.
skip
(
'Need to support op
erator prim::ListUnpack
'
)
# FIXME
@
unittest
.
skip
(
'Need to support
Lo
op'
)
# FIXME
def
test_time_sequence_prediction
(
self
):
def
test_time_sequence_prediction
(
self
):
class
Sequence
(
torch
.
jit
.
ScriptModule
):
class
Sequence
(
nn
.
Module
):
#
torch.jit.ScriptModule
def
__init__
(
self
):
def
__init__
(
self
):
super
(
Sequence
,
self
).
__init__
()
super
(
Sequence
,
self
).
__init__
()
self
.
lstm1
=
nn
.
LSTMCell
(
1
,
51
)
self
.
lstm1
=
nn
.
LSTMCell
(
1
,
51
)
self
.
lstm2
=
nn
.
LSTMCell
(
51
,
51
)
self
.
lstm2
=
nn
.
LSTMCell
(
51
,
51
)
self
.
linear
=
nn
.
Linear
(
51
,
1
)
self
.
linear
=
nn
.
Linear
(
51
,
1
)
@
torch
.
jit
.
script_method
#
@torch.jit.script_method
def
forward
(
self
,
input
):
def
forward
(
self
,
input
):
# TODO: add future as input with default val
# TODO: add future as input with default val
# see https://github.com/pytorch/pytorch/issues/8724
# see https://github.com/pytorch/pytorch/issues/8724
...
@@ -414,7 +422,7 @@ class TestConvert(unittest.TestCase):
...
@@ -414,7 +422,7 @@ class TestConvert(unittest.TestCase):
self
.
checkExportImport
(
Traced
(),
(
torch
.
rand
(
3
,
4
),))
self
.
checkExportImport
(
Traced
(),
(
torch
.
rand
(
3
,
4
),))
@
unittest
.
skip
(
'
Unsupported callmethod encode
'
)
# FIXME
@
unittest
.
skip
(
'
incorrectly assigned weights
'
)
# FIXME
def
test_vae
(
self
):
def
test_vae
(
self
):
class
VAE
(
nn
.
Module
):
class
VAE
(
nn
.
Module
):
def
__init__
(
self
):
def
__init__
(
self
):
...
@@ -449,11 +457,11 @@ class TestConvert(unittest.TestCase):
...
@@ -449,11 +457,11 @@ class TestConvert(unittest.TestCase):
self
.
checkExportImport
(
VAE
().
eval
(),
(
torch
.
rand
(
128
,
1
,
28
,
28
),))
self
.
checkExportImport
(
VAE
().
eval
(),
(
torch
.
rand
(
128
,
1
,
28
,
28
),))
@
unittest
.
skip
(
'torchvision models are not supported yet'
)
# FIXME
def
test_torchvision_resnet18
(
self
):
def
test_torchvision_resnet18
(
self
):
from
.inject_nn
import
inject_pytorch_nn
inject_pytorch_nn
()
self
.
checkExportImport
(
torchvision
.
models
.
resnet18
().
eval
(),
(
torch
.
ones
(
1
,
3
,
224
,
224
),))
self
.
checkExportImport
(
torchvision
.
models
.
resnet18
().
eval
(),
(
torch
.
ones
(
1
,
3
,
224
,
224
),))
@
unittest
.
skip
(
'Unsupported CallMethod _forward_impl'
)
# FIXME
def
test_resnet
(
self
):
def
test_resnet
(
self
):
def
conv1x1
(
in_planes
,
out_planes
,
stride
=
1
):
def
conv1x1
(
in_planes
,
out_planes
,
stride
=
1
):
"""1x1 convolution"""
"""1x1 convolution"""
...
@@ -464,7 +472,7 @@ class TestConvert(unittest.TestCase):
...
@@ -464,7 +472,7 @@ class TestConvert(unittest.TestCase):
return
nn
.
Conv2d
(
in_planes
,
out_planes
,
kernel_size
=
3
,
stride
=
stride
,
return
nn
.
Conv2d
(
in_planes
,
out_planes
,
kernel_size
=
3
,
stride
=
stride
,
padding
=
1
,
bias
=
False
)
padding
=
1
,
bias
=
False
)
class
BasicBlock
(
torch
.
jit
.
ScriptModule
):
class
BasicBlock
(
nn
.
Module
):
#
torch.jit.ScriptModule
expansion
=
1
expansion
=
1
__constants__
=
[
'downsample'
]
__constants__
=
[
'downsample'
]
...
@@ -478,7 +486,8 @@ class TestConvert(unittest.TestCase):
...
@@ -478,7 +486,8 @@ class TestConvert(unittest.TestCase):
self
.
downsample
=
downsample
self
.
downsample
=
downsample
self
.
stride
=
stride
self
.
stride
=
stride
@
torch
.
jit
.
script_method
# NOTE: jit cannot be annotated, otherwise, module id is not matched for recorded arguments
#@torch.jit.script_method
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
residual
=
x
residual
=
x
...
@@ -497,7 +506,8 @@ class TestConvert(unittest.TestCase):
...
@@ -497,7 +506,8 @@ class TestConvert(unittest.TestCase):
return
out
return
out
class
ResNet
(
torch
.
jit
.
ScriptModule
):
# NOTE: cannot inherit torch.jit.ScriptModule, otherwise, there would be error: 'RecursiveScriptModule' object has no attribute 'graph'
class
ResNet
(
nn
.
Module
):
#torch.jit.ScriptModule
__constants__
=
[
'layer1'
,
'layer2'
,
'layer3'
,
'layer4'
]
__constants__
=
[
'layer1'
,
'layer2'
,
'layer3'
,
'layer4'
]
def
__init__
(
self
,
block
,
layers
,
num_classes
=
1000
):
def
__init__
(
self
,
block
,
layers
,
num_classes
=
1000
):
...
@@ -538,7 +548,8 @@ class TestConvert(unittest.TestCase):
...
@@ -538,7 +548,8 @@ class TestConvert(unittest.TestCase):
return
nn
.
Sequential
(
*
layers
)
return
nn
.
Sequential
(
*
layers
)
@
torch
.
jit
.
script_method
# NOTE: jit cannot be annotated, otherwise, module id is not matched for recorded arguments
#@torch.jit.script_method
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
x
=
self
.
conv1
(
x
)
x
=
self
.
conv1
(
x
)
x
=
self
.
bn1
(
x
)
x
=
self
.
bn1
(
x
)
...
@@ -558,10 +569,11 @@ class TestConvert(unittest.TestCase):
...
@@ -558,10 +569,11 @@ class TestConvert(unittest.TestCase):
resnet18
=
ResNet
(
BasicBlock
,
[
2
,
2
,
2
,
2
])
resnet18
=
ResNet
(
BasicBlock
,
[
2
,
2
,
2
,
2
])
self
.
checkExportImport
(
torchvision
.
models
.
resnet18
().
eval
()
,
(
torch
.
randn
(
1
,
3
,
224
,
224
),))
self
.
checkExportImport
(
resnet18
,
(
torch
.
randn
(
1
,
3
,
224
,
224
),))
@
unittest
.
skip
(
'torchvision models are not supported yet'
)
# FIXME
def
test_alexnet
(
self
):
def
test_alexnet
(
self
):
from
.inject_nn
import
inject_pytorch_nn
inject_pytorch_nn
()
x
=
torch
.
ones
(
1
,
3
,
224
,
224
)
x
=
torch
.
ones
(
1
,
3
,
224
,
224
)
model
=
torchvision
.
models
.
AlexNet
()
model
=
torchvision
.
models
.
AlexNet
()
self
.
checkExportImport
(
model
,
(
x
,))
self
.
checkExportImport
(
model
,
(
x
,))
test/ut/retiarii/test_convert_basic.py
0 → 100644
View file @
58d5c2fa
import
os
import
sys
import
unittest
import
numpy
as
np
import
torch
import
torch.nn.functional
as
F
import
torchvision
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
blackbox_module
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.utils
import
get_records
# following pytorch v1.7.1
class
TestConvert
(
unittest
.
TestCase
):
@
staticmethod
def
_match_state_dict
(
current_values
,
expected_format
):
result
=
{}
for
k
,
v
in
expected_format
.
items
():
for
idx
,
cv
in
enumerate
(
current_values
):
if
cv
.
shape
==
v
.
shape
:
result
[
k
]
=
cv
current_values
.
pop
(
idx
)
break
return
result
def
checkExportImport
(
self
,
model
,
input
,
check_value
=
True
):
script_module
=
torch
.
jit
.
script
(
model
)
model_ir
=
convert_to_graph
(
script_module
,
model
)
model_code
=
model_to_pytorch_script
(
model_ir
)
print
(
model_code
)
exec_vars
=
{}
exec
(
model_code
+
'
\n\n
converted_model = _model()'
,
exec_vars
)
converted_model
=
exec_vars
[
'converted_model'
]
converted_state_dict
=
self
.
_match_state_dict
(
list
(
model
.
state_dict
().
values
()),
dict
(
converted_model
.
state_dict
()))
converted_model
.
load_state_dict
(
converted_state_dict
)
with
torch
.
no_grad
():
expected_output
=
model
.
eval
()(
*
input
)
converted_output
=
converted_model
.
eval
()(
*
input
)
if
check_value
:
self
.
assertEqual
(
len
(
converted_output
),
len
(
expected_output
))
for
a
,
b
in
zip
(
converted_output
,
expected_output
):
if
hasattr
(
a
,
'dtype'
)
and
a
.
dtype
==
torch
.
bool
:
self
.
assertEqual
((
a
^
b
),
False
)
elif
isinstance
((
a
-
b
),
int
):
self
.
assertEqual
((
a
-
b
),
0
)
else
:
self
.
assertLess
((
a
-
b
).
abs
().
max
().
item
(),
1E-4
)
return
converted_model
# skip torch.Tensor.new_tensor as it is not supported by jit
def
test_basic_new_full
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
):
# requires_grad is not supported by jit
# aten::new_full(Tensor self, int[] size, Scalar fill_value, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor):
# Keyword argument requires_grad unknown.
out
=
x
.
new_full
((
3
,
4
),
3.141592
,
dtype
=
torch
.
float32
,
device
=
torch
.
device
(
'cpu'
))
return
out
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
ones
((
2
,),
dtype
=
torch
.
float64
),
))
def
test_basic_new_empty
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
):
out
=
x
.
new_empty
((
2
,
3
),
dtype
=
torch
.
int8
,
device
=
torch
.
device
(
'cpu'
))
return
out
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
ones
(()),
),
check_value
=
False
)
# skip torch.Tensor.new_ones as it is not supported by jit
# requires_grad=False is not supported by jit
def
test_basic_new_zeros
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
):
out
=
x
.
new_zeros
((
2
,
3
))
return
out
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
tensor
((),
dtype
=
torch
.
int32
),
))
def
test_basic_is_cuda
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
):
return
torch
.
tensor
([
x
.
is_cuda
],
dtype
=
torch
.
bool
,
device
=
torch
.
device
(
'cpu'
))
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
tensor
((),
dtype
=
torch
.
int32
),
))
# is_quantized
# is_meta
# device
# grad
# ndim
# T
# real
# imag
def
test_basic_abs
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
):
out1
=
x
.
abs
()
out11
=
x
.
absolute
()
out2
=
torch
.
abs
(
x
)
#out3 = x.abs_()
#out33 = x.absolute_()
return
out1
,
out11
,
out2
#, out3, out33
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
tensor
([
-
1
,
-
2
,
3
]),
))
# TODO: topological sort should be improved
#def forward(self, x__1):
# __Acos2 = x__1.acos()
# __Acos_3 = x__1.acos_()
# __Acos1 = x__1.acos()
# __TupleConstruct4 = (__Acos1,__Acos2,__Acos_3)
# return __TupleConstruct4
def
test_basic_acos_asin_atan
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
,
y
):
out1
=
x
.
acos
()
out2
=
torch
.
acos
(
x
)
# TODO: add back this line
#out = x.acos_()
out3
=
x
.
asin
()
out4
=
torch
.
asin
(
x
)
out5
=
x
.
atan
()
out6
=
torch
.
atan
(
x
)
out7
=
x
.
atan2
(
y
)
out8
=
torch
.
atan2
(
x
,
y
)
return
out1
,
out2
,
out3
,
out4
,
out5
,
out6
,
out7
,
out8
#, out
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
tensor
([
-
1.0
,
-
0.5
,
0.2
]),
torch
.
tensor
([
1.0
,
0.6
,
-
0.3
]),
))
# arccos is not supported by jit
def
test_basic_add
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
):
t
=
torch
.
tensor
([
-
1.0
,
-
0.5
,
0.2
])
out1
=
x
.
add
(
t
)
out2
=
x
.
add
(
t
,
alpha
=
2
)
#out3 = x.add_(t)
return
out1
,
out2
#, out3
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
tensor
([
-
1.0
,
-
0.5
,
0.2
]),
))
def
test_basic_addbmm
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
,
y
,
z
,
m
):
out1
=
x
.
addbmm
(
y
,
z
,
beta
=
2
,
alpha
=
3
)
out2
=
torch
.
addbmm
(
x
,
y
,
z
,
beta
=
2
,
alpha
=
3
)
#out3 = x.addbmm_(y, z, beta=2, alpha=3)
out3
=
m
.
baddbmm
(
y
,
z
,
beta
=
2
,
alpha
=
3
)
out4
=
torch
.
baddbmm
(
m
,
y
,
z
,
beta
=
2
,
alpha
=
3
)
out5
=
torch
.
bmm
(
y
,
z
)
# deterministic is not supported by jit
return
out1
,
out2
,
out3
,
out4
,
out5
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
randn
(
3
,
5
),
torch
.
randn
(
10
,
3
,
4
),
torch
.
randn
(
10
,
4
,
5
),
torch
.
randn
(
10
,
3
,
5
),
))
def
test_basic_addcdiv
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
,
y
,
z
):
out1
=
x
.
addcdiv
(
y
,
z
,
value
=
2
)
out2
=
torch
.
addcdiv
(
x
,
y
,
z
,
value
=
2
)
# addcdiv_
return
out1
,
out2
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
randn
(
1
,
3
),
torch
.
randn
(
3
,
1
),
torch
.
randn
(
1
,
3
),
))
def
test_basic_addcmul
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
,
y
,
z
):
out1
=
x
.
addcmul
(
y
,
z
,
value
=
0.1
)
out2
=
torch
.
addcmul
(
x
,
y
,
z
,
value
=
0.1
)
# addcmul_
return
out1
,
out2
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
randn
(
1
,
3
),
torch
.
randn
(
3
,
1
),
torch
.
randn
(
1
,
3
),
))
def
test_basic_addmm
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
,
y
,
z
):
out1
=
x
.
addmm
(
y
,
z
,
beta
=
0.1
,
alpha
=
0.2
)
out2
=
torch
.
addmm
(
x
,
y
,
z
,
beta
=
0.1
,
alpha
=
0.2
)
# addmm_
return
out1
,
out2
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
randn
(
2
,
3
),
torch
.
randn
(
2
,
3
),
torch
.
randn
(
3
,
3
),
))
def
test_basic_addmv
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
,
y
,
z
):
out1
=
x
.
addmv
(
y
,
z
,
beta
=
0.1
,
alpha
=
0.2
)
out2
=
torch
.
addmv
(
x
,
y
,
z
,
beta
=
0.1
,
alpha
=
0.2
)
return
out1
,
out2
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
randn
(
2
),
torch
.
randn
(
2
,
3
),
torch
.
randn
(
3
),
))
def
test_basic_addr
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
,
y
,
z
):
out1
=
x
.
addr
(
y
,
z
,
beta
=
2
,
alpha
=
3
)
out2
=
torch
.
addr
(
x
,
y
,
z
,
beta
=
2
,
alpha
=
3
)
return
out1
,
out2
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
zeros
(
3
,
2
),
torch
.
arange
(
1.
,
4.
),
torch
.
arange
(
1.
,
3.
),
))
def
test_basic_allclose
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
,
y
):
out1
=
x
.
allclose
(
y
,
rtol
=
1e-05
,
atol
=
1e-08
,
equal_nan
=
False
)
out2
=
torch
.
allclose
(
x
,
y
,
rtol
=
1e-05
,
atol
=
1e-08
,
equal_nan
=
False
)
return
out1
,
out2
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
tensor
([
10000.
,
1e-07
]),
torch
.
tensor
([
10000.1
,
1e-08
]),
))
def
test_basic_angle
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
):
out1
=
x
.
angle
()
out2
=
torch
.
angle
(
x
)
return
out1
,
out2
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
tensor
([
-
1
+
1j
,
-
2
+
2j
,
3
-
3j
]),
))
# skip apply_(callable) for now
def
test_basic_argmax_argmin
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
):
out1
=
x
.
argmax
()
out2
=
torch
.
argmax
(
x
)
out3
=
x
.
argmax
(
dim
=
1
)
out4
=
torch
.
argmax
(
x
,
dim
=
1
)
out5
=
x
.
argmax
(
dim
=
1
,
keepdim
=
True
)
o1
=
x
.
argmin
()
o2
=
torch
.
argmin
(
x
)
o3
=
x
.
argmin
(
dim
=
1
)
o4
=
x
.
argmin
(
dim
=
1
,
keepdim
=
True
)
return
out1
,
out2
,
out3
,
out4
,
out5
,
o1
,
o2
,
o3
,
o4
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
randn
(
4
,
4
),
))
def
test_basic_argsort
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
):
out1
=
x
.
argsort
()
out2
=
x
.
argsort
(
dim
=
1
)
out3
=
x
.
argsort
(
dim
=
1
,
descending
=
True
)
out4
=
torch
.
argsort
(
x
,
dim
=
1
,
descending
=
True
)
return
out1
,
out2
,
out3
,
out4
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
randn
(
4
,
4
),
))
# skip backward(gradient=None, retain_graph=None, create_graph=False)
def
test_basic_bernoulli
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
):
# generator=torch.Generator() is not supported by jit
out
=
x
.
bernoulli
()
return
out
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
ones
(
3
,
3
),
))
# bfloat16/bool/byte/char is not supported by jit
def
test_basic_bincount
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
,
y
):
out1
=
x
.
bincount
()
out2
=
torch
.
bincount
(
x
)
out3
=
x
.
bincount
(
weights
=
y
)
out4
=
x
.
bincount
(
weights
=
y
,
minlength
=
2
)
return
out1
,
out2
,
out3
,
out4
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
randint
(
0
,
8
,
(
5
,),
dtype
=
torch
.
int64
),
torch
.
linspace
(
0
,
1
,
steps
=
5
),
))
def
test_basic_bitwise
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
,
y
):
out1
=
x
.
bitwise_not
()
out2
=
x
.
bitwise_and
(
y
)
out3
=
x
.
bitwise_or
(
y
)
out4
=
x
.
bitwise_xor
(
y
)
return
out1
,
out2
,
out3
,
out4
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
tensor
([
-
1
,
-
2
,
3
],
dtype
=
torch
.
int8
),
torch
.
tensor
([
1
,
0
,
3
],
dtype
=
torch
.
int8
),
))
# cauchy_ is not supported yet
def
test_ceil
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
):
out1
=
x
.
ceil
()
return
out1
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
randn
(
4
),
))
\ No newline at end of file
test/ut/retiarii/test_convert_models.py
0 → 100644
View file @
58d5c2fa
test/ut/retiarii/test_convert_operators.py
0 → 100644
View file @
58d5c2fa
This diff is collapsed.
Click to expand it.
test/ut/retiarii/test_convert_pytorch.py
0 → 100644
View file @
58d5c2fa
This diff is collapsed.
Click to expand it.
test/ut/retiarii/test_highlevel_apis.py
View file @
58d5c2fa
...
@@ -167,6 +167,7 @@ class TestHighLevelAPI(unittest.TestCase):
...
@@ -167,6 +167,7 @@ class TestHighLevelAPI(unittest.TestCase):
mutator
=
mutators
[
0
].
bind_sampler
(
EnuemrateSampler
())
mutator
=
mutators
[
0
].
bind_sampler
(
EnuemrateSampler
())
model1
=
mutator
.
apply
(
model
)
model1
=
mutator
.
apply
(
model
)
model2
=
mutator
.
apply
(
model
)
model2
=
mutator
.
apply
(
model
)
self
.
_get_converted_pytorch_model
(
model1
)(
torch
.
randn
(
1
,
3
,
3
,
3
))
self
.
assertEqual
(
self
.
_get_converted_pytorch_model
(
model1
)(
torch
.
randn
(
1
,
3
,
3
,
3
)).
size
(),
torch
.
Size
([
1
,
3
,
3
,
3
]))
self
.
assertEqual
(
self
.
_get_converted_pytorch_model
(
model1
)(
torch
.
randn
(
1
,
3
,
3
,
3
)).
size
(),
torch
.
Size
([
1
,
3
,
3
,
3
]))
self
.
assertAlmostEqual
(
self
.
_get_converted_pytorch_model
(
model2
)(
torch
.
randn
(
1
,
3
,
3
,
3
)).
abs
().
sum
().
item
(),
0
)
self
.
assertAlmostEqual
(
self
.
_get_converted_pytorch_model
(
model2
)(
torch
.
randn
(
1
,
3
,
3
,
3
)).
abs
().
sum
().
item
(),
0
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment