Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
58d5c2fa
Unverified
Commit
58d5c2fa
authored
Feb 22, 2021
by
QuanluZhang
Committed by
GitHub
Feb 22, 2021
Browse files
[retiarii] refactor of pytorch operators (#3365)
parent
59521d33
Changes
15
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
4020 additions
and
310 deletions
+4020
-310
nni/retiarii/codegen/pytorch.py
nni/retiarii/codegen/pytorch.py
+35
-6
nni/retiarii/converter/graph_gen.py
nni/retiarii/converter/graph_gen.py
+244
-120
nni/retiarii/converter/op_types.py
nni/retiarii/converter/op_types.py
+0
-26
nni/retiarii/nn/pytorch/mutator.py
nni/retiarii/nn/pytorch/mutator.py
+1
-1
nni/retiarii/operation.py
nni/retiarii/operation.py
+47
-52
nni/retiarii/operation_def/torch_op_def.py
nni/retiarii/operation_def/torch_op_def.py
+407
-29
nni/retiarii/utils.py
nni/retiarii/utils.py
+8
-0
pipelines/fast-test.yml
pipelines/fast-test.yml
+3
-1
test/ut/retiarii/inject_nn.py
test/ut/retiarii/inject_nn.py
+280
-0
test/ut/retiarii/test_convert.py
test/ut/retiarii/test_convert.py
+87
-75
test/ut/retiarii/test_convert_basic.py
test/ut/retiarii/test_convert_basic.py
+283
-0
test/ut/retiarii/test_convert_models.py
test/ut/retiarii/test_convert_models.py
+0
-0
test/ut/retiarii/test_convert_operators.py
test/ut/retiarii/test_convert_operators.py
+1390
-0
test/ut/retiarii/test_convert_pytorch.py
test/ut/retiarii/test_convert_pytorch.py
+1234
-0
test/ut/retiarii/test_highlevel_apis.py
test/ut/retiarii/test_highlevel_apis.py
+1
-0
No files found.
nni/retiarii/codegen/pytorch.py
View file @
58d5c2fa
import
logging
from
typing
import
List
from
typing
import
List
,
Tuple
,
Any
from
..graph
import
IllegalGraphError
,
Edge
,
Graph
,
Node
,
Model
...
...
@@ -32,9 +32,26 @@ def _sorted_incoming_edges(node: Node) -> List[Edge]:
raise
IllegalGraphError
(
node
.
graph
,
'Node {} has bad inputs'
.
format
(
node
.
name
))
def
_format_inputs
(
node
:
Node
)
->
List
[
str
]:
def
_format_inputs
(
node
:
Node
)
->
Tuple
[
List
[
str
],
List
[
Any
]]:
"""
Format the inputs of a given node
Parameters
----------
node : Node
a graph node, get and format its inputs
Returns
-------
list
the list of input names
list
the list of input values, if an input is simple type, record its value,
otherwise the value is None
"""
edges
=
_sorted_incoming_edges
(
node
)
inputs
=
[]
inputs_value
=
[]
for
edge
in
edges
:
if
edge
.
head
.
name
==
'_inputs'
:
assert
isinstance
(
edge
.
head_slot
,
int
)
...
...
@@ -44,14 +61,21 @@ def _format_inputs(node: Node) -> List[str]:
else
:
# when input has no name, e.g., forward(*_inputs)
inputs
.
append
(
'_inputs[{}]'
.
format
(
edge
.
head_slot
))
inputs_value
.
append
(
None
)
else
:
if
edge
.
head_slot
is
None
:
# when the input comes from a single-output operator
inputs
.
append
(
'{}'
.
format
(
edge
.
head
.
name
))
if
edge
.
head
.
operation
.
type
in
(
'prim::Constant'
,
'prim::GetAttr'
)
and
\
'value'
in
edge
.
head
.
operation
.
parameters
:
inputs_value
.
append
(
edge
.
head
.
operation
.
parameters
[
'value'
])
else
:
inputs_value
.
append
(
None
)
else
:
# when the input comes from a multi-output operator: needs to know which one it comes from
inputs
.
append
(
'{}[{}]'
.
format
(
edge
.
head
.
name
,
edge
.
head_slot
))
return
inputs
inputs_value
.
append
(
None
)
return
inputs
,
inputs_value
def
_remove_prefix
(
names
,
graph_name
):
...
...
@@ -80,6 +104,8 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
node_codes
=
[]
for
node
in
nodes
:
if
node
.
operation
:
if
node
.
operation
.
type
==
'shared'
:
continue
pkg_name
=
node
.
operation
.
get_import_pkg
()
if
pkg_name
is
not
None
:
import_pkgs
.
add
(
pkg_name
)
...
...
@@ -101,12 +127,15 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
sorted_nodes
=
graph
.
topo_sort
()
for
node
in
sorted_nodes
:
if
node
.
operation
:
inputs
=
_format_inputs
(
node
)
inputs
,
inputs_value
=
_format_inputs
(
node
)
inputs
=
_remove_prefix
(
inputs
,
graph_name
)
node_name
=
_remove_prefix
(
node
.
name
,
graph_name
)
edge_codes
.
append
(
node
.
operation
.
to_forward_code
(
node_name
,
node_name
,
inputs
))
submodule_name
=
node_name
if
node
.
operation
.
type
==
'shared'
:
submodule_name
=
_remove_prefix
(
node
.
operation
.
parameters
[
'reference'
],
graph_name
)
edge_codes
.
append
(
node
.
operation
.
to_forward_code
(
submodule_name
,
node_name
,
inputs
,
inputs_value
))
output_names
=
_format_inputs
(
graph
.
output_node
)
output_names
,
_
=
_format_inputs
(
graph
.
output_node
)
output_names
=
_remove_prefix
(
output_names
,
graph_name
)
if
not
output_names
:
raise
RuntimeError
(
'"forward" function should have return value(s): {}, {}, {}'
.
format
(
output_names
,
graph_name
,
graph
.
output_node
))
...
...
nni/retiarii/converter/graph_gen.py
View file @
58d5c2fa
This diff is collapsed.
Click to expand it.
nni/retiarii/converter/op_types.py
View file @
58d5c2fa
...
...
@@ -9,34 +9,8 @@ class OpTypeName(str, Enum):
"""
Attr
=
'Attr'
Constant
=
'Constant'
ListConstruct
=
'ListConstruct'
TupleConstruct
=
'TupleConstruct'
LayerChoice
=
'LayerChoice'
InputChoice
=
'InputChoice'
ValueChoice
=
'ValueChoice'
Placeholder
=
'Placeholder'
MergedSlice
=
'MergedSlice'
# deal with aten op
BasicOpsPT
=
{
'aten::mean'
:
'Mean'
,
'aten::relu'
:
'Relu'
,
'aten::add'
:
'Add'
,
'aten::__getitem__'
:
'getitem'
,
'aten::append'
:
'Append'
,
'aten::len'
:
'Len'
,
'aten::slice'
:
'Slice'
,
'aten::cat'
:
'Cat'
,
'aten::size'
:
'Size'
,
'aten::view'
:
'View'
,
'aten::reshape'
:
'Reshape'
,
'aten::eq'
:
'Eq'
,
'aten::Bool'
:
'Bool'
,
'aten::empty'
:
'Empty'
,
'aten::zeros'
:
'Zeros'
,
'aten::chunk'
:
'Chunk'
,
'aten::add_'
:
'Add_'
# %out.3 : Tensor = aten::add_(%out.1, %connection.1, %4)
}
BasicOpsTF
=
{}
nni/retiarii/nn/pytorch/mutator.py
View file @
58d5c2fa
...
...
@@ -45,7 +45,7 @@ class ValueChoiceMutator(Mutator):
chosen
=
self
.
choice
(
self
.
candidates
)
for
node
in
self
.
nodes
:
target
=
model
.
get_node_by_name
(
node
.
name
)
target
.
update_operation
(
'prim::Constant'
,
{
'value'
:
chosen
})
target
.
update_operation
(
'prim::Constant'
,
{
'type'
:
type
(
chosen
).
__name__
,
'value'
:
chosen
})
def
process_inline_mutation
(
model
:
Model
)
->
Optional
[
List
[
Mutator
]]:
...
...
nni/retiarii/operation.py
View file @
58d5c2fa
...
...
@@ -83,6 +83,31 @@ class Operation:
class
PyTorchOperation
(
Operation
):
@
classmethod
def
_find_subclass
(
cls
,
subclass_name
):
if
cls
.
to_class_name
(
subclass_name
)
is
not
None
:
subclass_name
=
'ModuleOperator'
if
cls
.
is_functional
(
subclass_name
):
subclass_name
=
'FunctionalOperator'
for
subclass
in
cls
.
__subclasses__
():
if
hasattr
(
subclass
,
'_ori_type_name'
)
and
\
subclass_name
in
subclass
.
_ori_type_name
:
return
subclass
return
cls
@
classmethod
def
to_class_name
(
cls
,
type_name
)
->
str
:
if
type_name
.
startswith
(
'__torch__.'
):
return
type_name
[
len
(
'__torch__.'
):]
elif
type_name
.
startswith
(
'__mutated__.'
):
return
type_name
[
len
(
'__mutated__.'
):]
else
:
return
None
@
classmethod
def
is_functional
(
cls
,
type_name
)
->
bool
:
return
type_name
.
startswith
(
'Function.'
)
def
_to_class_name
(
self
)
->
str
:
if
self
.
type
.
startswith
(
'__torch__.'
):
return
self
.
type
[
len
(
'__torch__.'
):]
...
...
@@ -106,59 +131,27 @@ class PyTorchOperation(Operation):
return
f
'self.
{
field
}
=
{
self
.
_to_class_name
()
}
(
{
kw_params
}
)'
return
None
def
to_forward_code
(
self
,
field
:
str
,
output
:
str
,
inputs
:
List
[
str
])
->
str
:
from
.converter.op_types
import
OpTypeName
if
self
.
_to_class_name
()
is
not
None
:
return
f
'
{
output
}
= self.
{
field
}
(
{
", "
.
join
(
inputs
)
}
)'
elif
self
.
type
.
startswith
(
'Function.'
):
func_name
=
self
.
type
[
len
(
'Function.'
):]
return
f
'
{
output
}
= F.
{
func_name
}
(
{
", "
.
join
(
inputs
)
}
)'
elif
self
.
type
==
'prim::Constant'
:
if
self
.
parameters
:
value
=
self
.
parameters
[
'value'
]
else
:
value
=
None
return
f
'
{
output
}
=
{
value
}
'
elif
self
.
type
==
'prim::ListConstruct'
:
return
f
'
{
output
}
= [
{
", "
.
join
(
inputs
)
}
]'
elif
self
.
type
==
'prim::TupleConstruct'
:
return
f
'
{
output
}
= (
{
", "
.
join
(
inputs
)
}
)'
elif
self
.
type
==
'prim::GetAttr'
:
return
f
"
{
output
}
=
{
self
.
parameters
[
'input'
]
}
.
{
self
.
parameters
[
'name'
]
}
"
elif
self
.
type
==
'aten::mean'
:
return
f
'
{
output
}
= torch.mean(
{
inputs
[
0
]
}
,
{
", "
.
join
(
inputs
[
1
:
-
1
])
}
, out=
{
inputs
[
-
1
]
}
)'
elif
self
.
type
==
'aten::__getitem__'
:
assert
len
(
inputs
)
==
2
return
f
'
{
output
}
=
{
inputs
[
0
]
}
[
{
inputs
[
1
]
}
]'
elif
self
.
type
==
'aten::append'
:
assert
len
(
inputs
)
==
2
return
f
'_,
{
output
}
=
{
inputs
[
0
]
}
.append(
{
inputs
[
1
]
}
),
{
inputs
[
0
]
}
'
elif
self
.
type
==
'aten::cat'
:
assert
len
(
inputs
)
==
2
return
f
'
{
output
}
= torch.cat(
{
inputs
[
0
]
}
, dim=
{
inputs
[
1
]
}
)'
elif
self
.
type
==
'aten::add'
:
return
f
'
{
output
}
= '
+
' + '
.
join
(
inputs
)
elif
self
.
type
==
OpTypeName
.
MergedSlice
:
assert
(
len
(
inputs
)
-
1
)
%
4
==
0
slices
=
[]
dim
=
int
((
len
(
inputs
)
-
1
)
/
4
)
for
i
in
range
(
dim
):
slices
.
append
(
f
'
{
inputs
[
i
*
4
+
2
]
}
:
{
inputs
[
i
*
4
+
3
]
}
:
{
inputs
[
i
*
4
+
4
]
}
'
)
slice_str
=
','
.
join
(
slices
)
return
f
'
{
output
}
=
{
inputs
[
0
]
}
[
{
slice_str
}
]'
elif
self
.
type
==
'aten::size'
:
assert
len
(
inputs
)
==
2
return
f
'
{
output
}
=
{
inputs
[
0
]
}
.size(
{
inputs
[
1
]
}
)'
elif
self
.
type
==
'aten::view'
:
assert
len
(
inputs
)
==
2
return
f
'
{
output
}
=
{
inputs
[
0
]
}
.view(
{
inputs
[
1
]
}
)'
elif
self
.
type
==
'aten::reshape'
:
assert
len
(
inputs
)
==
2
return
f
'
{
output
}
=
{
inputs
[
0
]
}
.reshape(
{
inputs
[
1
]
}
)'
elif
self
.
type
==
'aten::slice'
:
def
to_forward_code
(
self
,
field
:
str
,
output
:
str
,
inputs
:
List
[
str
],
inputs_value
:
List
[
Any
]
=
None
)
->
str
:
"""
Parameters
----------
field : str
the name of member submodule
output : str
the output name (lvalue) of this line of code
inputs : List[str]
variables used in this line of code
inputs_value : List[Any]
some variables are actually constant, their real values are recorded in ```inputs_value```.
if not constant, we simply put None at the corresponding index
Returns
-------
str
generated code line
"""
if
self
.
type
==
'aten::slice'
:
raise
RuntimeError
(
'not supposed to have aten::slice operation'
)
elif
self
.
type
==
'aten::Bool'
:
return
f
'
{
output
}
= bool(
{
inputs
[
0
]
}
)'
else
:
raise
RuntimeError
(
f
'unsupported operation type:
{
self
.
type
}
?
{
self
.
_to_class_name
()
}
'
)
...
...
@@ -212,6 +205,8 @@ class Cell(PyTorchOperation):
# TODO: ugly, think about how to refactor this part
return
_convert_name
(
self
.
cell_name
)
def
to_forward_code
(
self
,
field
:
str
,
output
:
str
,
inputs
:
List
[
str
],
inputs_value
:
List
[
Any
]
=
None
)
->
str
:
return
f
'
{
output
}
= self.
{
field
}
(
{
", "
.
join
(
inputs
)
}
)'
class
_IOPseudoOperation
(
Operation
):
"""
...
...
nni/retiarii/operation_def/torch_op_def.py
View file @
58d5c2fa
This diff is collapsed.
Click to expand it.
nni/retiarii/utils.py
View file @
58d5c2fa
...
...
@@ -162,6 +162,14 @@ def _get_module_name(cls):
f
'please launch the experiment under the directory where "
{
main_file_path
.
name
}
" is located.'
)
module_name
=
main_file_path
.
stem
break
# NOTE: this is hacky. As torchscript retrieves LSTM's source code to do something.
# to make LSTM's source code can be found, we should assign original LSTM's __module__ to
# the wrapped LSTM's __module__
# TODO: find out all the modules that have the same requirement as LSTM
if
f
'
{
cls
.
__module__
}
.
{
cls
.
__name__
}
'
==
'torch.nn.modules.rnn.LSTM'
:
module_name
=
cls
.
__module__
return
module_name
...
...
pipelines/fast-test.yml
View file @
58d5c2fa
...
...
@@ -250,7 +250,9 @@ stages:
-
script
:
|
cd test
python -m pytest ut
python -m pytest ut --ignore=ut/retiarii/test_convert_basic.py \
--ignore=ut/retiarii/test_convert_operators.py \
--ignore=ut/retiarii/test_convert_pytorch.py
displayName
:
Python unit test
-
script
:
|
...
...
test/ut/retiarii/inject_nn.py
0 → 100644
View file @
58d5c2fa
import
inspect
import
logging
import
torch
import
torch.nn
as
nn
from
nni.retiarii.utils
import
add_record
,
del_record
,
version_larger_equal
_logger
=
logging
.
getLogger
(
__name__
)
def
wrap_module
(
original_class
):
orig_init
=
original_class
.
__init__
argname_list
=
list
(
inspect
.
signature
(
original_class
).
parameters
.
keys
())
# Make copy of original __init__, so we can call it without recursion
original_class
.
bak_init_for_inject
=
orig_init
if
hasattr
(
original_class
,
'__del__'
):
orig_del
=
original_class
.
__del__
original_class
.
bak_del_for_inject
=
orig_del
else
:
orig_del
=
None
original_class
.
bak_del_for_inject
=
None
def
__init__
(
self
,
*
args
,
**
kws
):
full_args
=
{}
full_args
.
update
(
kws
)
for
i
,
arg
in
enumerate
(
args
):
full_args
[
argname_list
[
i
]]
=
arg
add_record
(
id
(
self
),
full_args
)
orig_init
(
self
,
*
args
,
**
kws
)
# Call the original __init__
def
__del__
(
self
):
del_record
(
id
(
self
))
if
orig_del
is
not
None
:
orig_del
(
self
)
original_class
.
__init__
=
__init__
# Set the class' __init__ to the new one
original_class
.
__del__
=
__del__
return
original_class
def
unwrap_module
(
wrapped_class
):
if
hasattr
(
wrapped_class
,
'bak_init_for_inject'
):
wrapped_class
.
__init__
=
wrapped_class
.
bak_init_for_inject
delattr
(
wrapped_class
,
'bak_init_for_inject'
)
if
hasattr
(
wrapped_class
,
'bak_del_for_inject'
):
if
wrapped_class
.
bak_del_for_inject
is
not
None
:
wrapped_class
.
__del__
=
wrapped_class
.
bak_del_for_inject
delattr
(
wrapped_class
,
'bak_del_for_inject'
)
return
None
def
remove_inject_pytorch_nn
():
Identity
=
unwrap_module
(
nn
.
Identity
)
Linear
=
unwrap_module
(
nn
.
Linear
)
Conv1d
=
unwrap_module
(
nn
.
Conv1d
)
Conv2d
=
unwrap_module
(
nn
.
Conv2d
)
Conv3d
=
unwrap_module
(
nn
.
Conv3d
)
ConvTranspose1d
=
unwrap_module
(
nn
.
ConvTranspose1d
)
ConvTranspose2d
=
unwrap_module
(
nn
.
ConvTranspose2d
)
ConvTranspose3d
=
unwrap_module
(
nn
.
ConvTranspose3d
)
Threshold
=
unwrap_module
(
nn
.
Threshold
)
ReLU
=
unwrap_module
(
nn
.
ReLU
)
Hardtanh
=
unwrap_module
(
nn
.
Hardtanh
)
ReLU6
=
unwrap_module
(
nn
.
ReLU6
)
Sigmoid
=
unwrap_module
(
nn
.
Sigmoid
)
Tanh
=
unwrap_module
(
nn
.
Tanh
)
Softmax
=
unwrap_module
(
nn
.
Softmax
)
Softmax2d
=
unwrap_module
(
nn
.
Softmax2d
)
LogSoftmax
=
unwrap_module
(
nn
.
LogSoftmax
)
ELU
=
unwrap_module
(
nn
.
ELU
)
SELU
=
unwrap_module
(
nn
.
SELU
)
CELU
=
unwrap_module
(
nn
.
CELU
)
GLU
=
unwrap_module
(
nn
.
GLU
)
GELU
=
unwrap_module
(
nn
.
GELU
)
Hardshrink
=
unwrap_module
(
nn
.
Hardshrink
)
LeakyReLU
=
unwrap_module
(
nn
.
LeakyReLU
)
LogSigmoid
=
unwrap_module
(
nn
.
LogSigmoid
)
Softplus
=
unwrap_module
(
nn
.
Softplus
)
Softshrink
=
unwrap_module
(
nn
.
Softshrink
)
MultiheadAttention
=
unwrap_module
(
nn
.
MultiheadAttention
)
PReLU
=
unwrap_module
(
nn
.
PReLU
)
Softsign
=
unwrap_module
(
nn
.
Softsign
)
Softmin
=
unwrap_module
(
nn
.
Softmin
)
Tanhshrink
=
unwrap_module
(
nn
.
Tanhshrink
)
RReLU
=
unwrap_module
(
nn
.
RReLU
)
AvgPool1d
=
unwrap_module
(
nn
.
AvgPool1d
)
AvgPool2d
=
unwrap_module
(
nn
.
AvgPool2d
)
AvgPool3d
=
unwrap_module
(
nn
.
AvgPool3d
)
MaxPool1d
=
unwrap_module
(
nn
.
MaxPool1d
)
MaxPool2d
=
unwrap_module
(
nn
.
MaxPool2d
)
MaxPool3d
=
unwrap_module
(
nn
.
MaxPool3d
)
MaxUnpool1d
=
unwrap_module
(
nn
.
MaxUnpool1d
)
MaxUnpool2d
=
unwrap_module
(
nn
.
MaxUnpool2d
)
MaxUnpool3d
=
unwrap_module
(
nn
.
MaxUnpool3d
)
FractionalMaxPool2d
=
unwrap_module
(
nn
.
FractionalMaxPool2d
)
FractionalMaxPool3d
=
unwrap_module
(
nn
.
FractionalMaxPool3d
)
LPPool1d
=
unwrap_module
(
nn
.
LPPool1d
)
LPPool2d
=
unwrap_module
(
nn
.
LPPool2d
)
LocalResponseNorm
=
unwrap_module
(
nn
.
LocalResponseNorm
)
BatchNorm1d
=
unwrap_module
(
nn
.
BatchNorm1d
)
BatchNorm2d
=
unwrap_module
(
nn
.
BatchNorm2d
)
BatchNorm3d
=
unwrap_module
(
nn
.
BatchNorm3d
)
InstanceNorm1d
=
unwrap_module
(
nn
.
InstanceNorm1d
)
InstanceNorm2d
=
unwrap_module
(
nn
.
InstanceNorm2d
)
InstanceNorm3d
=
unwrap_module
(
nn
.
InstanceNorm3d
)
LayerNorm
=
unwrap_module
(
nn
.
LayerNorm
)
GroupNorm
=
unwrap_module
(
nn
.
GroupNorm
)
SyncBatchNorm
=
unwrap_module
(
nn
.
SyncBatchNorm
)
Dropout
=
unwrap_module
(
nn
.
Dropout
)
Dropout2d
=
unwrap_module
(
nn
.
Dropout2d
)
Dropout3d
=
unwrap_module
(
nn
.
Dropout3d
)
AlphaDropout
=
unwrap_module
(
nn
.
AlphaDropout
)
FeatureAlphaDropout
=
unwrap_module
(
nn
.
FeatureAlphaDropout
)
ReflectionPad1d
=
unwrap_module
(
nn
.
ReflectionPad1d
)
ReflectionPad2d
=
unwrap_module
(
nn
.
ReflectionPad2d
)
ReplicationPad2d
=
unwrap_module
(
nn
.
ReplicationPad2d
)
ReplicationPad1d
=
unwrap_module
(
nn
.
ReplicationPad1d
)
ReplicationPad3d
=
unwrap_module
(
nn
.
ReplicationPad3d
)
CrossMapLRN2d
=
unwrap_module
(
nn
.
CrossMapLRN2d
)
Embedding
=
unwrap_module
(
nn
.
Embedding
)
EmbeddingBag
=
unwrap_module
(
nn
.
EmbeddingBag
)
RNNBase
=
unwrap_module
(
nn
.
RNNBase
)
RNN
=
unwrap_module
(
nn
.
RNN
)
LSTM
=
unwrap_module
(
nn
.
LSTM
)
GRU
=
unwrap_module
(
nn
.
GRU
)
RNNCellBase
=
unwrap_module
(
nn
.
RNNCellBase
)
RNNCell
=
unwrap_module
(
nn
.
RNNCell
)
LSTMCell
=
unwrap_module
(
nn
.
LSTMCell
)
GRUCell
=
unwrap_module
(
nn
.
GRUCell
)
PixelShuffle
=
unwrap_module
(
nn
.
PixelShuffle
)
Upsample
=
unwrap_module
(
nn
.
Upsample
)
UpsamplingNearest2d
=
unwrap_module
(
nn
.
UpsamplingNearest2d
)
UpsamplingBilinear2d
=
unwrap_module
(
nn
.
UpsamplingBilinear2d
)
PairwiseDistance
=
unwrap_module
(
nn
.
PairwiseDistance
)
AdaptiveMaxPool1d
=
unwrap_module
(
nn
.
AdaptiveMaxPool1d
)
AdaptiveMaxPool2d
=
unwrap_module
(
nn
.
AdaptiveMaxPool2d
)
AdaptiveMaxPool3d
=
unwrap_module
(
nn
.
AdaptiveMaxPool3d
)
AdaptiveAvgPool1d
=
unwrap_module
(
nn
.
AdaptiveAvgPool1d
)
AdaptiveAvgPool2d
=
unwrap_module
(
nn
.
AdaptiveAvgPool2d
)
AdaptiveAvgPool3d
=
unwrap_module
(
nn
.
AdaptiveAvgPool3d
)
TripletMarginLoss
=
unwrap_module
(
nn
.
TripletMarginLoss
)
ZeroPad2d
=
unwrap_module
(
nn
.
ZeroPad2d
)
ConstantPad1d
=
unwrap_module
(
nn
.
ConstantPad1d
)
ConstantPad2d
=
unwrap_module
(
nn
.
ConstantPad2d
)
ConstantPad3d
=
unwrap_module
(
nn
.
ConstantPad3d
)
Bilinear
=
unwrap_module
(
nn
.
Bilinear
)
CosineSimilarity
=
unwrap_module
(
nn
.
CosineSimilarity
)
Unfold
=
unwrap_module
(
nn
.
Unfold
)
Fold
=
unwrap_module
(
nn
.
Fold
)
AdaptiveLogSoftmaxWithLoss
=
unwrap_module
(
nn
.
AdaptiveLogSoftmaxWithLoss
)
TransformerEncoder
=
unwrap_module
(
nn
.
TransformerEncoder
)
TransformerDecoder
=
unwrap_module
(
nn
.
TransformerDecoder
)
TransformerEncoderLayer
=
unwrap_module
(
nn
.
TransformerEncoderLayer
)
TransformerDecoderLayer
=
unwrap_module
(
nn
.
TransformerDecoderLayer
)
Transformer
=
unwrap_module
(
nn
.
Transformer
)
Flatten
=
unwrap_module
(
nn
.
Flatten
)
Hardsigmoid
=
unwrap_module
(
nn
.
Hardsigmoid
)
if
version_larger_equal
(
torch
.
__version__
,
'1.6.0'
):
Hardswish
=
unwrap_module
(
nn
.
Hardswish
)
if
version_larger_equal
(
torch
.
__version__
,
'1.7.0'
):
SiLU
=
unwrap_module
(
nn
.
SiLU
)
Unflatten
=
unwrap_module
(
nn
.
Unflatten
)
TripletMarginWithDistanceLoss
=
unwrap_module
(
nn
.
TripletMarginWithDistanceLoss
)
def
inject_pytorch_nn
():
Identity
=
wrap_module
(
nn
.
Identity
)
Linear
=
wrap_module
(
nn
.
Linear
)
Conv1d
=
wrap_module
(
nn
.
Conv1d
)
Conv2d
=
wrap_module
(
nn
.
Conv2d
)
Conv3d
=
wrap_module
(
nn
.
Conv3d
)
ConvTranspose1d
=
wrap_module
(
nn
.
ConvTranspose1d
)
ConvTranspose2d
=
wrap_module
(
nn
.
ConvTranspose2d
)
ConvTranspose3d
=
wrap_module
(
nn
.
ConvTranspose3d
)
Threshold
=
wrap_module
(
nn
.
Threshold
)
ReLU
=
wrap_module
(
nn
.
ReLU
)
Hardtanh
=
wrap_module
(
nn
.
Hardtanh
)
ReLU6
=
wrap_module
(
nn
.
ReLU6
)
Sigmoid
=
wrap_module
(
nn
.
Sigmoid
)
Tanh
=
wrap_module
(
nn
.
Tanh
)
Softmax
=
wrap_module
(
nn
.
Softmax
)
Softmax2d
=
wrap_module
(
nn
.
Softmax2d
)
LogSoftmax
=
wrap_module
(
nn
.
LogSoftmax
)
ELU
=
wrap_module
(
nn
.
ELU
)
SELU
=
wrap_module
(
nn
.
SELU
)
CELU
=
wrap_module
(
nn
.
CELU
)
GLU
=
wrap_module
(
nn
.
GLU
)
GELU
=
wrap_module
(
nn
.
GELU
)
Hardshrink
=
wrap_module
(
nn
.
Hardshrink
)
LeakyReLU
=
wrap_module
(
nn
.
LeakyReLU
)
LogSigmoid
=
wrap_module
(
nn
.
LogSigmoid
)
Softplus
=
wrap_module
(
nn
.
Softplus
)
Softshrink
=
wrap_module
(
nn
.
Softshrink
)
MultiheadAttention
=
wrap_module
(
nn
.
MultiheadAttention
)
PReLU
=
wrap_module
(
nn
.
PReLU
)
Softsign
=
wrap_module
(
nn
.
Softsign
)
Softmin
=
wrap_module
(
nn
.
Softmin
)
Tanhshrink
=
wrap_module
(
nn
.
Tanhshrink
)
RReLU
=
wrap_module
(
nn
.
RReLU
)
AvgPool1d
=
wrap_module
(
nn
.
AvgPool1d
)
AvgPool2d
=
wrap_module
(
nn
.
AvgPool2d
)
AvgPool3d
=
wrap_module
(
nn
.
AvgPool3d
)
MaxPool1d
=
wrap_module
(
nn
.
MaxPool1d
)
MaxPool2d
=
wrap_module
(
nn
.
MaxPool2d
)
MaxPool3d
=
wrap_module
(
nn
.
MaxPool3d
)
MaxUnpool1d
=
wrap_module
(
nn
.
MaxUnpool1d
)
MaxUnpool2d
=
wrap_module
(
nn
.
MaxUnpool2d
)
MaxUnpool3d
=
wrap_module
(
nn
.
MaxUnpool3d
)
FractionalMaxPool2d
=
wrap_module
(
nn
.
FractionalMaxPool2d
)
FractionalMaxPool3d
=
wrap_module
(
nn
.
FractionalMaxPool3d
)
LPPool1d
=
wrap_module
(
nn
.
LPPool1d
)
LPPool2d
=
wrap_module
(
nn
.
LPPool2d
)
LocalResponseNorm
=
wrap_module
(
nn
.
LocalResponseNorm
)
BatchNorm1d
=
wrap_module
(
nn
.
BatchNorm1d
)
BatchNorm2d
=
wrap_module
(
nn
.
BatchNorm2d
)
BatchNorm3d
=
wrap_module
(
nn
.
BatchNorm3d
)
InstanceNorm1d
=
wrap_module
(
nn
.
InstanceNorm1d
)
InstanceNorm2d
=
wrap_module
(
nn
.
InstanceNorm2d
)
InstanceNorm3d
=
wrap_module
(
nn
.
InstanceNorm3d
)
LayerNorm
=
wrap_module
(
nn
.
LayerNorm
)
GroupNorm
=
wrap_module
(
nn
.
GroupNorm
)
SyncBatchNorm
=
wrap_module
(
nn
.
SyncBatchNorm
)
Dropout
=
wrap_module
(
nn
.
Dropout
)
Dropout2d
=
wrap_module
(
nn
.
Dropout2d
)
Dropout3d
=
wrap_module
(
nn
.
Dropout3d
)
AlphaDropout
=
wrap_module
(
nn
.
AlphaDropout
)
FeatureAlphaDropout
=
wrap_module
(
nn
.
FeatureAlphaDropout
)
ReflectionPad1d
=
wrap_module
(
nn
.
ReflectionPad1d
)
ReflectionPad2d
=
wrap_module
(
nn
.
ReflectionPad2d
)
ReplicationPad2d
=
wrap_module
(
nn
.
ReplicationPad2d
)
ReplicationPad1d
=
wrap_module
(
nn
.
ReplicationPad1d
)
ReplicationPad3d
=
wrap_module
(
nn
.
ReplicationPad3d
)
CrossMapLRN2d
=
wrap_module
(
nn
.
CrossMapLRN2d
)
Embedding
=
wrap_module
(
nn
.
Embedding
)
EmbeddingBag
=
wrap_module
(
nn
.
EmbeddingBag
)
RNNBase
=
wrap_module
(
nn
.
RNNBase
)
RNN
=
wrap_module
(
nn
.
RNN
)
LSTM
=
wrap_module
(
nn
.
LSTM
)
GRU
=
wrap_module
(
nn
.
GRU
)
RNNCellBase
=
wrap_module
(
nn
.
RNNCellBase
)
RNNCell
=
wrap_module
(
nn
.
RNNCell
)
LSTMCell
=
wrap_module
(
nn
.
LSTMCell
)
GRUCell
=
wrap_module
(
nn
.
GRUCell
)
PixelShuffle
=
wrap_module
(
nn
.
PixelShuffle
)
Upsample
=
wrap_module
(
nn
.
Upsample
)
UpsamplingNearest2d
=
wrap_module
(
nn
.
UpsamplingNearest2d
)
UpsamplingBilinear2d
=
wrap_module
(
nn
.
UpsamplingBilinear2d
)
PairwiseDistance
=
wrap_module
(
nn
.
PairwiseDistance
)
AdaptiveMaxPool1d
=
wrap_module
(
nn
.
AdaptiveMaxPool1d
)
AdaptiveMaxPool2d
=
wrap_module
(
nn
.
AdaptiveMaxPool2d
)
AdaptiveMaxPool3d
=
wrap_module
(
nn
.
AdaptiveMaxPool3d
)
AdaptiveAvgPool1d
=
wrap_module
(
nn
.
AdaptiveAvgPool1d
)
AdaptiveAvgPool2d
=
wrap_module
(
nn
.
AdaptiveAvgPool2d
)
AdaptiveAvgPool3d
=
wrap_module
(
nn
.
AdaptiveAvgPool3d
)
TripletMarginLoss
=
wrap_module
(
nn
.
TripletMarginLoss
)
ZeroPad2d
=
wrap_module
(
nn
.
ZeroPad2d
)
ConstantPad1d
=
wrap_module
(
nn
.
ConstantPad1d
)
ConstantPad2d
=
wrap_module
(
nn
.
ConstantPad2d
)
ConstantPad3d
=
wrap_module
(
nn
.
ConstantPad3d
)
Bilinear
=
wrap_module
(
nn
.
Bilinear
)
CosineSimilarity
=
wrap_module
(
nn
.
CosineSimilarity
)
Unfold
=
wrap_module
(
nn
.
Unfold
)
Fold
=
wrap_module
(
nn
.
Fold
)
AdaptiveLogSoftmaxWithLoss
=
wrap_module
(
nn
.
AdaptiveLogSoftmaxWithLoss
)
TransformerEncoder
=
wrap_module
(
nn
.
TransformerEncoder
)
TransformerDecoder
=
wrap_module
(
nn
.
TransformerDecoder
)
TransformerEncoderLayer
=
wrap_module
(
nn
.
TransformerEncoderLayer
)
TransformerDecoderLayer
=
wrap_module
(
nn
.
TransformerDecoderLayer
)
Transformer
=
wrap_module
(
nn
.
Transformer
)
Flatten
=
wrap_module
(
nn
.
Flatten
)
Hardsigmoid
=
wrap_module
(
nn
.
Hardsigmoid
)
if
version_larger_equal
(
torch
.
__version__
,
'1.6.0'
):
Hardswish
=
wrap_module
(
nn
.
Hardswish
)
if
version_larger_equal
(
torch
.
__version__
,
'1.7.0'
):
SiLU
=
wrap_module
(
nn
.
SiLU
)
Unflatten
=
wrap_module
(
nn
.
Unflatten
)
TripletMarginWithDistanceLoss
=
wrap_module
(
nn
.
TripletMarginWithDistanceLoss
)
test/ut/retiarii/test_convert.py
View file @
58d5c2fa
...
...
@@ -35,16 +35,29 @@ class MnistNet(nn.Module):
x
=
self
.
fc2
(
x
)
return
F
.
log_softmax
(
x
,
dim
=
1
)
# NOTE: blackbox module cannot be placed within class or function
@
blackbox_module
class
Linear
(
nn
.
Module
):
def
__init__
(
self
,
d_embed
,
d_proj
):
super
().
__init__
()
self
.
linear
=
nn
.
Linear
(
d_embed
,
d_proj
)
def
forward
(
self
,
input
):
if
len
(
input
.
size
())
<=
2
:
return
self
.
linear
(
input
)
size
=
input
.
size
()[:
2
]
out
=
self
.
linear
(
input
.
view
(
size
[
0
]
*
size
[
1
],
-
1
))
return
out
.
view
(
size
[
0
],
size
[
1
],
-
1
)
class
TestConvert
(
unittest
.
TestCase
):
@
staticmethod
def
_match_state_dict
(
current_values
,
expected_format
):
result
=
{}
for
k
,
v
in
expected_format
.
items
():
for
cv
in
current_values
:
for
idx
,
cv
in
enumerate
(
current_values
)
:
if
cv
.
shape
==
v
.
shape
:
result
[
k
]
=
cv
current_values
.
remove
(
cv
)
current_values
.
pop
(
idx
)
break
return
result
...
...
@@ -53,6 +66,9 @@ class TestConvert(unittest.TestCase):
model_ir
=
convert_to_graph
(
script_module
,
model
)
model_code
=
model_to_pytorch_script
(
model_ir
)
from
.inject_nn
import
remove_inject_pytorch_nn
remove_inject_pytorch_nn
()
exec_vars
=
{}
exec
(
model_code
+
'
\n\n
converted_model = _model()'
,
exec_vars
)
converted_model
=
exec_vars
[
'converted_model'
]
...
...
@@ -134,18 +150,17 @@ class TestConvert(unittest.TestCase):
model
=
DCGANGenerator
(
nz
,
ngf
,
nc
)
self
.
checkExportImport
(
model
,
input
)
@
unittest
.
skip
(
'this test has a if condition that needs to be handle'
)
# FIXME
def
test_neural_style
(
self
):
class
TransformerNet
(
torch
.
nn
.
Module
):
class
TransformerNet
(
nn
.
Module
):
def
__init__
(
self
):
super
(
TransformerNet
,
self
).
__init__
()
# Initial convolution layers
self
.
conv1
=
ConvLayer
(
3
,
32
,
kernel_size
=
9
,
stride
=
1
)
self
.
in1
=
torch
.
nn
.
InstanceNorm2d
(
32
,
affine
=
True
)
self
.
in1
=
nn
.
InstanceNorm2d
(
32
,
affine
=
True
)
self
.
conv2
=
ConvLayer
(
32
,
64
,
kernel_size
=
3
,
stride
=
2
)
self
.
in2
=
torch
.
nn
.
InstanceNorm2d
(
64
,
affine
=
True
)
self
.
in2
=
nn
.
InstanceNorm2d
(
64
,
affine
=
True
)
self
.
conv3
=
ConvLayer
(
64
,
128
,
kernel_size
=
3
,
stride
=
2
)
self
.
in3
=
torch
.
nn
.
InstanceNorm2d
(
128
,
affine
=
True
)
self
.
in3
=
nn
.
InstanceNorm2d
(
128
,
affine
=
True
)
# Residual layers
self
.
res1
=
ResidualBlock
(
128
)
self
.
res2
=
ResidualBlock
(
128
)
...
...
@@ -154,12 +169,12 @@ class TestConvert(unittest.TestCase):
self
.
res5
=
ResidualBlock
(
128
)
# Upsampling Layers
self
.
deconv1
=
UpsampleConvLayer
(
128
,
64
,
kernel_size
=
3
,
stride
=
1
,
upsample
=
2
)
self
.
in4
=
torch
.
nn
.
InstanceNorm2d
(
64
,
affine
=
True
)
self
.
in4
=
nn
.
InstanceNorm2d
(
64
,
affine
=
True
)
self
.
deconv2
=
UpsampleConvLayer
(
64
,
32
,
kernel_size
=
3
,
stride
=
1
,
upsample
=
2
)
self
.
in5
=
torch
.
nn
.
InstanceNorm2d
(
32
,
affine
=
True
)
self
.
in5
=
nn
.
InstanceNorm2d
(
32
,
affine
=
True
)
self
.
deconv3
=
ConvLayer
(
32
,
3
,
kernel_size
=
9
,
stride
=
1
)
# Non-linearities
self
.
relu
=
torch
.
nn
.
ReLU
()
self
.
relu
=
nn
.
ReLU
()
def
forward
(
self
,
X
):
y
=
self
.
relu
(
self
.
in1
(
self
.
conv1
(
X
)))
...
...
@@ -175,19 +190,19 @@ class TestConvert(unittest.TestCase):
y
=
self
.
deconv3
(
y
)
return
y
class
ConvLayer
(
torch
.
nn
.
Module
):
class
ConvLayer
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
,
stride
):
super
(
ConvLayer
,
self
).
__init__
()
reflection_padding
=
kernel_size
//
2
self
.
reflection_pad
=
torch
.
nn
.
ReflectionPad2d
(
reflection_padding
)
self
.
conv2d
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
,
stride
)
self
.
reflection_pad
=
nn
.
ReflectionPad2d
(
reflection_padding
)
self
.
conv2d
=
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
,
stride
)
def
forward
(
self
,
x
):
out
=
self
.
reflection_pad
(
x
)
out
=
self
.
conv2d
(
out
)
return
out
class
ResidualBlock
(
torch
.
nn
.
Module
):
class
ResidualBlock
(
nn
.
Module
):
"""ResidualBlock
introduced in: https://arxiv.org/abs/1512.03385
recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html
...
...
@@ -196,10 +211,10 @@ class TestConvert(unittest.TestCase):
def
__init__
(
self
,
channels
):
super
(
ResidualBlock
,
self
).
__init__
()
self
.
conv1
=
ConvLayer
(
channels
,
channels
,
kernel_size
=
3
,
stride
=
1
)
self
.
in1
=
torch
.
nn
.
InstanceNorm2d
(
channels
,
affine
=
True
)
self
.
in1
=
nn
.
InstanceNorm2d
(
channels
,
affine
=
True
)
self
.
conv2
=
ConvLayer
(
channels
,
channels
,
kernel_size
=
3
,
stride
=
1
)
self
.
in2
=
torch
.
nn
.
InstanceNorm2d
(
channels
,
affine
=
True
)
self
.
relu
=
torch
.
nn
.
ReLU
()
self
.
in2
=
nn
.
InstanceNorm2d
(
channels
,
affine
=
True
)
self
.
relu
=
nn
.
ReLU
()
def
forward
(
self
,
x
):
residual
=
x
...
...
@@ -208,7 +223,7 @@ class TestConvert(unittest.TestCase):
out
=
out
+
residual
return
out
class
UpsampleConvLayer
(
torch
.
nn
.
Module
):
class
UpsampleConvLayer
(
nn
.
Module
):
"""UpsampleConvLayer
Upsamples the input and then does a convolution. This method gives better results
compared to ConvTranspose2d.
...
...
@@ -219,10 +234,10 @@ class TestConvert(unittest.TestCase):
super
(
UpsampleConvLayer
,
self
).
__init__
()
self
.
upsample
=
upsample
if
upsample
:
self
.
upsample_layer
=
torch
.
nn
.
Upsample
(
mode
=
'nearest'
,
scale_factor
=
upsample
)
self
.
upsample_layer
=
nn
.
Upsample
(
mode
=
'nearest'
,
scale_factor
=
upsample
)
reflection_padding
=
kernel_size
//
2
self
.
reflection_pad
=
torch
.
nn
.
ReflectionPad2d
(
reflection_padding
)
self
.
conv2d
=
torch
.
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
,
stride
)
self
.
reflection_pad
=
nn
.
ReflectionPad2d
(
reflection_padding
)
self
.
conv2d
=
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
,
stride
)
def
forward
(
self
,
x
):
x_in
=
x
...
...
@@ -254,50 +269,40 @@ class TestConvert(unittest.TestCase):
self
.
checkExportImport
(
Policy
(),
(
torch
.
rand
(
1
,
4
),))
@
unittest
.
skip
(
'Replaced init error.'
)
# FIXME
def
test_snli
(
self
):
class
Bottle
(
nn
.
Module
):
def
forward
(
self
,
input
):
if
len
(
input
.
size
())
<=
2
:
return
super
(
Bottle
,
self
).
forward
(
input
)
size
=
input
.
size
()[:
2
]
out
=
super
(
Bottle
,
self
).
forward
(
input
.
view
(
size
[
0
]
*
size
[
1
],
-
1
))
return
out
.
view
(
size
[
0
],
size
[
1
],
-
1
)
class
Linear
(
Bottle
,
nn
.
Linear
):
pass
class
Encoder
(
nn
.
Module
):
def
__init__
(
self
,
config
):
super
(
Encoder
,
self
).
__init__
()
self
.
config
=
config
input_size
=
config
.
d_proj
if
config
.
projection
else
config
.
d_embed
dropout
=
0
if
config
.
n_layers
==
1
else
config
.
dp_ratio
self
.
rnn
=
nn
.
LSTM
(
input_size
=
input_size
,
hidden_size
=
config
.
d_hidden
,
num_layers
=
config
.
n_layers
,
dropout
=
dropout
,
bidirectional
=
config
.
birnn
)
#self.config = config
input_size
=
config
[
"d_proj"
]
if
config
[
"projection"
]
else
config
[
"d_embed"
]
dropout
=
0
if
config
[
"n_layers"
]
==
1
else
config
[
"dp_ratio"
]
self
.
rnn
=
nn
.
LSTM
(
input_size
=
input_size
,
hidden_size
=
config
[
"d_hidden"
],
num_layers
=
config
[
"n_layers"
],
dropout
=
dropout
,
bidirectional
=
config
[
"birnn"
])
self
.
n_cells
=
config
[
"n_cells"
]
self
.
d_hidden
=
config
[
"d_hidden"
]
self
.
birnn
=
config
[
"birnn"
]
def
forward
(
self
,
inputs
):
batch_size
=
inputs
.
size
()[
1
]
state_shape
=
self
.
config
.
n_cells
,
batch_size
,
self
.
config
.
d_hidden
state_shape
=
self
.
n_cells
,
batch_size
,
self
.
d_hidden
h0
=
c0
=
inputs
.
new_zeros
(
state_shape
)
outputs
,
(
ht
,
ct
)
=
self
.
rnn
(
inputs
,
(
h0
,
c0
))
return
ht
[
-
1
]
if
not
self
.
config
.
birnn
else
ht
[
-
2
:].
transpose
(
0
,
1
).
contiguous
().
view
(
batch_size
,
-
1
)
return
ht
[
-
1
]
if
not
self
.
birnn
else
ht
[
-
2
:].
transpose
(
0
,
1
).
contiguous
().
view
(
batch_size
,
-
1
)
class
SNLIClassifier
(
nn
.
Module
):
def
__init__
(
self
,
config
):
super
(
SNLIClassifier
,
self
).
__init__
()
self
.
config
=
config
self
.
embed
=
nn
.
Embedding
(
config
.
n_embed
,
config
.
d_embed
)
self
.
projection
=
Linear
(
config
.
d_embed
,
config
.
d_proj
)
self
.
embed
=
nn
.
Embedding
(
config
[
"n_embed"
],
config
[
"d_embed"
])
self
.
projection
=
Linear
(
config
[
"d_embed"
],
config
[
"d_proj"
])
self
.
encoder
=
Encoder
(
config
)
self
.
dropout
=
nn
.
Dropout
(
p
=
config
.
dp_ratio
)
self
.
dropout
=
nn
.
Dropout
(
p
=
config
[
"
dp_ratio
"
]
)
self
.
relu
=
nn
.
ReLU
()
seq_in_size
=
2
*
config
.
d_hidden
if
self
.
config
.
birnn
:
seq_in_size
=
2
*
config
[
"
d_hidden
"
]
if
config
[
"
birnn
"
]
:
seq_in_size
*=
2
lin_config
=
[
seq_in_size
]
*
2
self
.
out
=
nn
.
Sequential
(
...
...
@@ -310,15 +315,17 @@ class TestConvert(unittest.TestCase):
Linear
(
*
lin_config
),
self
.
relu
,
self
.
dropout
,
Linear
(
seq_in_size
,
config
.
d_out
))
Linear
(
seq_in_size
,
config
[
"d_out"
]))
self
.
fix_emb
=
config
[
"fix_emb"
]
self
.
project
=
config
[
"projection"
]
def
forward
(
self
,
premise
,
hypothesis
):
prem_embed
=
self
.
embed
(
premise
)
hypo_embed
=
self
.
embed
(
hypothesis
)
if
self
.
config
.
fix_emb
:
if
self
.
fix_emb
:
prem_embed
=
prem_embed
.
detach
()
hypo_embed
=
hypo_embed
.
detach
()
if
self
.
config
.
project
ion
:
if
self
.
project
:
prem_embed
=
self
.
relu
(
self
.
projection
(
prem_embed
))
hypo_embed
=
self
.
relu
(
self
.
projection
(
hypo_embed
))
premise
=
self
.
encoder
(
prem_embed
)
...
...
@@ -326,23 +333,24 @@ class TestConvert(unittest.TestCase):
scores
=
self
.
out
(
torch
.
cat
([
premise
,
hypothesis
],
1
))
return
scores
class
Config
:
n_embed
=
100
d_embed
=
100
d_proj
=
300
dp_ratio
=
0.0
# For deterministic testing TODO: change by fixing seed in checkTrace?
d_hidden
=
30
birnn
=
True
d_out
=
300
fix_emb
=
True
projection
=
True
n_layers
=
2
n_cells
=
4
# 2 * n_layers because birnn = True
Config
=
{
"n_embed"
:
100
,
"d_embed"
:
100
,
"d_proj"
:
300
,
"dp_ratio"
:
0.0
,
# For deterministic testing TOD": change by fixing seed in checkTrace?,
"d_hidden"
:
30
,
"birnn"
:
True
,
"d_out"
:
300
,
"fix_emb"
:
True
,
"projection"
:
True
,
"n_layers"
:
2
,
"n_cells"
:
4
# 2 * n_layers because birnn = True,
}
premise
=
torch
.
LongTensor
(
48
,
64
).
random_
(
0
,
100
)
hypothesis
=
torch
.
LongTensor
(
24
,
64
).
random_
(
0
,
100
)
self
.
checkExportImport
(
SNLIClassifier
(
Config
()
),
(
premise
,
hypothesis
))
self
.
checkExportImport
(
SNLIClassifier
(
Config
),
(
premise
,
hypothesis
))
def
test_super_resolution
(
self
):
class
Net
(
nn
.
Module
):
...
...
@@ -367,16 +375,16 @@ class TestConvert(unittest.TestCase):
net
=
Net
(
upscale_factor
=
4
)
self
.
checkExportImport
(
net
,
(
torch
.
rand
(
5
,
1
,
32
,
32
),))
@
unittest
.
skip
(
'Need to support op
erator prim::ListUnpack
'
)
# FIXME
@
unittest
.
skip
(
'Need to support
Lo
op'
)
# FIXME
def
test_time_sequence_prediction
(
self
):
class
Sequence
(
torch
.
jit
.
ScriptModule
):
class
Sequence
(
nn
.
Module
):
#
torch.jit.ScriptModule
def
__init__
(
self
):
super
(
Sequence
,
self
).
__init__
()
self
.
lstm1
=
nn
.
LSTMCell
(
1
,
51
)
self
.
lstm2
=
nn
.
LSTMCell
(
51
,
51
)
self
.
linear
=
nn
.
Linear
(
51
,
1
)
@
torch
.
jit
.
script_method
#
@torch.jit.script_method
def
forward
(
self
,
input
):
# TODO: add future as input with default val
# see https://github.com/pytorch/pytorch/issues/8724
...
...
@@ -414,7 +422,7 @@ class TestConvert(unittest.TestCase):
self
.
checkExportImport
(
Traced
(),
(
torch
.
rand
(
3
,
4
),))
@
unittest
.
skip
(
'
Unsupported callmethod encode
'
)
# FIXME
@
unittest
.
skip
(
'
incorrectly assigned weights
'
)
# FIXME
def
test_vae
(
self
):
class
VAE
(
nn
.
Module
):
def
__init__
(
self
):
...
...
@@ -449,11 +457,11 @@ class TestConvert(unittest.TestCase):
self
.
checkExportImport
(
VAE
().
eval
(),
(
torch
.
rand
(
128
,
1
,
28
,
28
),))
@
unittest
.
skip
(
'torchvision models are not supported yet'
)
# FIXME
def
test_torchvision_resnet18
(
self
):
from
.inject_nn
import
inject_pytorch_nn
inject_pytorch_nn
()
self
.
checkExportImport
(
torchvision
.
models
.
resnet18
().
eval
(),
(
torch
.
ones
(
1
,
3
,
224
,
224
),))
@
unittest
.
skip
(
'Unsupported CallMethod _forward_impl'
)
# FIXME
def
test_resnet
(
self
):
def
conv1x1
(
in_planes
,
out_planes
,
stride
=
1
):
"""1x1 convolution"""
...
...
@@ -464,7 +472,7 @@ class TestConvert(unittest.TestCase):
return
nn
.
Conv2d
(
in_planes
,
out_planes
,
kernel_size
=
3
,
stride
=
stride
,
padding
=
1
,
bias
=
False
)
class
BasicBlock
(
torch
.
jit
.
ScriptModule
):
class
BasicBlock
(
nn
.
Module
):
#
torch.jit.ScriptModule
expansion
=
1
__constants__
=
[
'downsample'
]
...
...
@@ -478,7 +486,8 @@ class TestConvert(unittest.TestCase):
self
.
downsample
=
downsample
self
.
stride
=
stride
@
torch
.
jit
.
script_method
# NOTE: jit cannot be annotated, otherwise, module id is not matched for recorded arguments
#@torch.jit.script_method
def
forward
(
self
,
x
):
residual
=
x
...
...
@@ -497,7 +506,8 @@ class TestConvert(unittest.TestCase):
return
out
class
ResNet
(
torch
.
jit
.
ScriptModule
):
# NOTE: cannot inherit torch.jit.ScriptModule, otherwise, there would be error: 'RecursiveScriptModule' object has no attribute 'graph'
class
ResNet
(
nn
.
Module
):
#torch.jit.ScriptModule
__constants__
=
[
'layer1'
,
'layer2'
,
'layer3'
,
'layer4'
]
def
__init__
(
self
,
block
,
layers
,
num_classes
=
1000
):
...
...
@@ -538,7 +548,8 @@ class TestConvert(unittest.TestCase):
return
nn
.
Sequential
(
*
layers
)
@
torch
.
jit
.
script_method
# NOTE: jit cannot be annotated, otherwise, module id is not matched for recorded arguments
#@torch.jit.script_method
def
forward
(
self
,
x
):
x
=
self
.
conv1
(
x
)
x
=
self
.
bn1
(
x
)
...
...
@@ -558,10 +569,11 @@ class TestConvert(unittest.TestCase):
resnet18
=
ResNet
(
BasicBlock
,
[
2
,
2
,
2
,
2
])
self
.
checkExportImport
(
torchvision
.
models
.
resnet18
().
eval
()
,
(
torch
.
randn
(
1
,
3
,
224
,
224
),))
self
.
checkExportImport
(
resnet18
,
(
torch
.
randn
(
1
,
3
,
224
,
224
),))
@
unittest
.
skip
(
'torchvision models are not supported yet'
)
# FIXME
def
test_alexnet
(
self
):
from
.inject_nn
import
inject_pytorch_nn
inject_pytorch_nn
()
x
=
torch
.
ones
(
1
,
3
,
224
,
224
)
model
=
torchvision
.
models
.
AlexNet
()
self
.
checkExportImport
(
model
,
(
x
,))
test/ut/retiarii/test_convert_basic.py
0 → 100644
View file @
58d5c2fa
import
os
import
sys
import
unittest
import
numpy
as
np
import
torch
import
torch.nn.functional
as
F
import
torchvision
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
blackbox_module
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.utils
import
get_records
# following pytorch v1.7.1
class
TestConvert
(
unittest
.
TestCase
):
@
staticmethod
def
_match_state_dict
(
current_values
,
expected_format
):
result
=
{}
for
k
,
v
in
expected_format
.
items
():
for
idx
,
cv
in
enumerate
(
current_values
):
if
cv
.
shape
==
v
.
shape
:
result
[
k
]
=
cv
current_values
.
pop
(
idx
)
break
return
result
def
checkExportImport
(
self
,
model
,
input
,
check_value
=
True
):
script_module
=
torch
.
jit
.
script
(
model
)
model_ir
=
convert_to_graph
(
script_module
,
model
)
model_code
=
model_to_pytorch_script
(
model_ir
)
print
(
model_code
)
exec_vars
=
{}
exec
(
model_code
+
'
\n\n
converted_model = _model()'
,
exec_vars
)
converted_model
=
exec_vars
[
'converted_model'
]
converted_state_dict
=
self
.
_match_state_dict
(
list
(
model
.
state_dict
().
values
()),
dict
(
converted_model
.
state_dict
()))
converted_model
.
load_state_dict
(
converted_state_dict
)
with
torch
.
no_grad
():
expected_output
=
model
.
eval
()(
*
input
)
converted_output
=
converted_model
.
eval
()(
*
input
)
if
check_value
:
self
.
assertEqual
(
len
(
converted_output
),
len
(
expected_output
))
for
a
,
b
in
zip
(
converted_output
,
expected_output
):
if
hasattr
(
a
,
'dtype'
)
and
a
.
dtype
==
torch
.
bool
:
self
.
assertEqual
((
a
^
b
),
False
)
elif
isinstance
((
a
-
b
),
int
):
self
.
assertEqual
((
a
-
b
),
0
)
else
:
self
.
assertLess
((
a
-
b
).
abs
().
max
().
item
(),
1E-4
)
return
converted_model
# skip torch.Tensor.new_tensor as it is not supported by jit
def
test_basic_new_full
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
):
# requires_grad is not supported by jit
# aten::new_full(Tensor self, int[] size, Scalar fill_value, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor):
# Keyword argument requires_grad unknown.
out
=
x
.
new_full
((
3
,
4
),
3.141592
,
dtype
=
torch
.
float32
,
device
=
torch
.
device
(
'cpu'
))
return
out
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
ones
((
2
,),
dtype
=
torch
.
float64
),
))
def
test_basic_new_empty
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
):
out
=
x
.
new_empty
((
2
,
3
),
dtype
=
torch
.
int8
,
device
=
torch
.
device
(
'cpu'
))
return
out
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
ones
(()),
),
check_value
=
False
)
# skip torch.Tensor.new_ones as it is not supported by jit
# requires_grad=False is not supported by jit
def
test_basic_new_zeros
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
):
out
=
x
.
new_zeros
((
2
,
3
))
return
out
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
tensor
((),
dtype
=
torch
.
int32
),
))
def
test_basic_is_cuda
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
):
return
torch
.
tensor
([
x
.
is_cuda
],
dtype
=
torch
.
bool
,
device
=
torch
.
device
(
'cpu'
))
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
tensor
((),
dtype
=
torch
.
int32
),
))
# is_quantized
# is_meta
# device
# grad
# ndim
# T
# real
# imag
def
test_basic_abs
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
):
out1
=
x
.
abs
()
out11
=
x
.
absolute
()
out2
=
torch
.
abs
(
x
)
#out3 = x.abs_()
#out33 = x.absolute_()
return
out1
,
out11
,
out2
#, out3, out33
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
tensor
([
-
1
,
-
2
,
3
]),
))
# TODO: topological sort should be improved
#def forward(self, x__1):
# __Acos2 = x__1.acos()
# __Acos_3 = x__1.acos_()
# __Acos1 = x__1.acos()
# __TupleConstruct4 = (__Acos1,__Acos2,__Acos_3)
# return __TupleConstruct4
def
test_basic_acos_asin_atan
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
,
y
):
out1
=
x
.
acos
()
out2
=
torch
.
acos
(
x
)
# TODO: add back this line
#out = x.acos_()
out3
=
x
.
asin
()
out4
=
torch
.
asin
(
x
)
out5
=
x
.
atan
()
out6
=
torch
.
atan
(
x
)
out7
=
x
.
atan2
(
y
)
out8
=
torch
.
atan2
(
x
,
y
)
return
out1
,
out2
,
out3
,
out4
,
out5
,
out6
,
out7
,
out8
#, out
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
tensor
([
-
1.0
,
-
0.5
,
0.2
]),
torch
.
tensor
([
1.0
,
0.6
,
-
0.3
]),
))
# arccos is not supported by jit
def
test_basic_add
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
):
t
=
torch
.
tensor
([
-
1.0
,
-
0.5
,
0.2
])
out1
=
x
.
add
(
t
)
out2
=
x
.
add
(
t
,
alpha
=
2
)
#out3 = x.add_(t)
return
out1
,
out2
#, out3
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
tensor
([
-
1.0
,
-
0.5
,
0.2
]),
))
def
test_basic_addbmm
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
,
y
,
z
,
m
):
out1
=
x
.
addbmm
(
y
,
z
,
beta
=
2
,
alpha
=
3
)
out2
=
torch
.
addbmm
(
x
,
y
,
z
,
beta
=
2
,
alpha
=
3
)
#out3 = x.addbmm_(y, z, beta=2, alpha=3)
out3
=
m
.
baddbmm
(
y
,
z
,
beta
=
2
,
alpha
=
3
)
out4
=
torch
.
baddbmm
(
m
,
y
,
z
,
beta
=
2
,
alpha
=
3
)
out5
=
torch
.
bmm
(
y
,
z
)
# deterministic is not supported by jit
return
out1
,
out2
,
out3
,
out4
,
out5
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
randn
(
3
,
5
),
torch
.
randn
(
10
,
3
,
4
),
torch
.
randn
(
10
,
4
,
5
),
torch
.
randn
(
10
,
3
,
5
),
))
def
test_basic_addcdiv
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
,
y
,
z
):
out1
=
x
.
addcdiv
(
y
,
z
,
value
=
2
)
out2
=
torch
.
addcdiv
(
x
,
y
,
z
,
value
=
2
)
# addcdiv_
return
out1
,
out2
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
randn
(
1
,
3
),
torch
.
randn
(
3
,
1
),
torch
.
randn
(
1
,
3
),
))
def
test_basic_addcmul
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
,
y
,
z
):
out1
=
x
.
addcmul
(
y
,
z
,
value
=
0.1
)
out2
=
torch
.
addcmul
(
x
,
y
,
z
,
value
=
0.1
)
# addcmul_
return
out1
,
out2
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
randn
(
1
,
3
),
torch
.
randn
(
3
,
1
),
torch
.
randn
(
1
,
3
),
))
def
test_basic_addmm
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
,
y
,
z
):
out1
=
x
.
addmm
(
y
,
z
,
beta
=
0.1
,
alpha
=
0.2
)
out2
=
torch
.
addmm
(
x
,
y
,
z
,
beta
=
0.1
,
alpha
=
0.2
)
# addmm_
return
out1
,
out2
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
randn
(
2
,
3
),
torch
.
randn
(
2
,
3
),
torch
.
randn
(
3
,
3
),
))
def
test_basic_addmv
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
,
y
,
z
):
out1
=
x
.
addmv
(
y
,
z
,
beta
=
0.1
,
alpha
=
0.2
)
out2
=
torch
.
addmv
(
x
,
y
,
z
,
beta
=
0.1
,
alpha
=
0.2
)
return
out1
,
out2
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
randn
(
2
),
torch
.
randn
(
2
,
3
),
torch
.
randn
(
3
),
))
def
test_basic_addr
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
,
y
,
z
):
out1
=
x
.
addr
(
y
,
z
,
beta
=
2
,
alpha
=
3
)
out2
=
torch
.
addr
(
x
,
y
,
z
,
beta
=
2
,
alpha
=
3
)
return
out1
,
out2
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
zeros
(
3
,
2
),
torch
.
arange
(
1.
,
4.
),
torch
.
arange
(
1.
,
3.
),
))
def
test_basic_allclose
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
,
y
):
out1
=
x
.
allclose
(
y
,
rtol
=
1e-05
,
atol
=
1e-08
,
equal_nan
=
False
)
out2
=
torch
.
allclose
(
x
,
y
,
rtol
=
1e-05
,
atol
=
1e-08
,
equal_nan
=
False
)
return
out1
,
out2
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
tensor
([
10000.
,
1e-07
]),
torch
.
tensor
([
10000.1
,
1e-08
]),
))
def
test_basic_angle
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
):
out1
=
x
.
angle
()
out2
=
torch
.
angle
(
x
)
return
out1
,
out2
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
tensor
([
-
1
+
1j
,
-
2
+
2j
,
3
-
3j
]),
))
# skip apply_(callable) for now
def
test_basic_argmax_argmin
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
):
out1
=
x
.
argmax
()
out2
=
torch
.
argmax
(
x
)
out3
=
x
.
argmax
(
dim
=
1
)
out4
=
torch
.
argmax
(
x
,
dim
=
1
)
out5
=
x
.
argmax
(
dim
=
1
,
keepdim
=
True
)
o1
=
x
.
argmin
()
o2
=
torch
.
argmin
(
x
)
o3
=
x
.
argmin
(
dim
=
1
)
o4
=
x
.
argmin
(
dim
=
1
,
keepdim
=
True
)
return
out1
,
out2
,
out3
,
out4
,
out5
,
o1
,
o2
,
o3
,
o4
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
randn
(
4
,
4
),
))
def
test_basic_argsort
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
):
out1
=
x
.
argsort
()
out2
=
x
.
argsort
(
dim
=
1
)
out3
=
x
.
argsort
(
dim
=
1
,
descending
=
True
)
out4
=
torch
.
argsort
(
x
,
dim
=
1
,
descending
=
True
)
return
out1
,
out2
,
out3
,
out4
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
randn
(
4
,
4
),
))
# skip backward(gradient=None, retain_graph=None, create_graph=False)
def
test_basic_bernoulli
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
):
# generator=torch.Generator() is not supported by jit
out
=
x
.
bernoulli
()
return
out
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
ones
(
3
,
3
),
))
# bfloat16/bool/byte/char is not supported by jit
def
test_basic_bincount
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
,
y
):
out1
=
x
.
bincount
()
out2
=
torch
.
bincount
(
x
)
out3
=
x
.
bincount
(
weights
=
y
)
out4
=
x
.
bincount
(
weights
=
y
,
minlength
=
2
)
return
out1
,
out2
,
out3
,
out4
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
randint
(
0
,
8
,
(
5
,),
dtype
=
torch
.
int64
),
torch
.
linspace
(
0
,
1
,
steps
=
5
),
))
def
test_basic_bitwise
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
,
y
):
out1
=
x
.
bitwise_not
()
out2
=
x
.
bitwise_and
(
y
)
out3
=
x
.
bitwise_or
(
y
)
out4
=
x
.
bitwise_xor
(
y
)
return
out1
,
out2
,
out3
,
out4
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
tensor
([
-
1
,
-
2
,
3
],
dtype
=
torch
.
int8
),
torch
.
tensor
([
1
,
0
,
3
],
dtype
=
torch
.
int8
),
))
# cauchy_ is not supported yet
def
test_ceil
(
self
):
class
SimpleOp
(
nn
.
Module
):
def
forward
(
self
,
x
):
out1
=
x
.
ceil
()
return
out1
self
.
checkExportImport
(
SimpleOp
(),
(
torch
.
randn
(
4
),
))
\ No newline at end of file
test/ut/retiarii/test_convert_models.py
0 → 100644
View file @
58d5c2fa
test/ut/retiarii/test_convert_operators.py
0 → 100644
View file @
58d5c2fa
This diff is collapsed.
Click to expand it.
test/ut/retiarii/test_convert_pytorch.py
0 → 100644
View file @
58d5c2fa
This diff is collapsed.
Click to expand it.
test/ut/retiarii/test_highlevel_apis.py
View file @
58d5c2fa
...
...
@@ -167,6 +167,7 @@ class TestHighLevelAPI(unittest.TestCase):
mutator
=
mutators
[
0
].
bind_sampler
(
EnuemrateSampler
())
model1
=
mutator
.
apply
(
model
)
model2
=
mutator
.
apply
(
model
)
self
.
_get_converted_pytorch_model
(
model1
)(
torch
.
randn
(
1
,
3
,
3
,
3
))
self
.
assertEqual
(
self
.
_get_converted_pytorch_model
(
model1
)(
torch
.
randn
(
1
,
3
,
3
,
3
)).
size
(),
torch
.
Size
([
1
,
3
,
3
,
3
]))
self
.
assertAlmostEqual
(
self
.
_get_converted_pytorch_model
(
model2
)(
torch
.
randn
(
1
,
3
,
3
,
3
)).
abs
().
sum
().
item
(),
0
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment