Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
0f88b86b
"testing/vscode:/vscode.git/clone" did not exist on "5f202fe5c1d63a5e3a1598690877eccff2ad4640"
Unverified
Commit
0f88b86b
authored
Jan 05, 2021
by
Yuge Zhang
Committed by
GitHub
Jan 05, 2021
Browse files
Retiarii graph and code generation test (#3231)
parent
4fae3ed9
Changes
4
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
582 additions
and
3 deletions
+582
-3
nni/retiarii/converter/op_types.py
nni/retiarii/converter/op_types.py
+4
-0
nni/retiarii/operation.py
nni/retiarii/operation.py
+5
-2
nni/retiarii/utils.py
nni/retiarii/utils.py
+6
-1
test/ut/retiarii/test_convert.py
test/ut/retiarii/test_convert.py
+567
-0
No files found.
nni/retiarii/converter/op_types.py
View file @
0f88b86b
...
...
@@ -30,6 +30,10 @@ BasicOpsPT = {
'aten::size'
:
'Size'
,
'aten::view'
:
'View'
,
'aten::eq'
:
'Eq'
,
'aten::Bool'
:
'Bool'
,
'aten::empty'
:
'Empty'
,
'aten::zeros'
:
'Zeros'
,
'aten::chunk'
:
'Chunk'
,
'aten::add_'
:
'Add_'
# %out.3 : Tensor = aten::add_(%out.1, %connection.1, %4)
}
...
...
nni/retiarii/operation.py
View file @
0f88b86b
...
...
@@ -121,6 +121,8 @@ class PyTorchOperation(Operation):
return
f
'
{
output
}
=
{
value
}
'
elif
self
.
type
==
'prim::ListConstruct'
:
return
f
'
{
output
}
= [
{
", "
.
join
(
inputs
)
}
]'
elif
self
.
type
==
'prim::GetAttr'
:
return
f
"
{
output
}
=
{
self
.
parameters
[
'input'
]
}
.
{
self
.
parameters
[
'name'
]
}
"
elif
self
.
type
==
'aten::mean'
:
return
f
'
{
output
}
= torch.mean(
{
inputs
[
0
]
}
,
{
", "
.
join
(
inputs
[
1
:
-
1
])
}
, out=
{
inputs
[
-
1
]
}
)'
elif
self
.
type
==
'aten::__getitem__'
:
...
...
@@ -133,8 +135,7 @@ class PyTorchOperation(Operation):
assert
len
(
inputs
)
==
2
return
f
'
{
output
}
= torch.cat(
{
inputs
[
0
]
}
, dim=
{
inputs
[
1
]
}
)'
elif
self
.
type
==
'aten::add'
:
assert
len
(
inputs
)
==
2
return
f
'
{
output
}
=
{
inputs
[
0
]
}
+
{
inputs
[
1
]
}
'
return
f
'
{
output
}
= '
+
' + '
.
join
(
inputs
)
elif
self
.
type
==
OpTypeName
.
MergedSlice
:
assert
(
len
(
inputs
)
-
1
)
%
4
==
0
slices
=
[]
...
...
@@ -151,6 +152,8 @@ class PyTorchOperation(Operation):
return
f
'
{
output
}
=
{
inputs
[
0
]
}
.view(
{
inputs
[
1
]
}
)'
elif
self
.
type
==
'aten::slice'
:
raise
RuntimeError
(
'not supposed to have aten::slice operation'
)
elif
self
.
type
==
'aten::Bool'
:
return
f
'
{
output
}
= bool(
{
inputs
[
0
]
}
)'
else
:
raise
RuntimeError
(
f
'unsupported operation type:
{
self
.
type
}
?
{
self
.
_to_class_name
()
}
'
)
...
...
nni/retiarii/utils.py
View file @
0f88b86b
...
...
@@ -27,6 +27,11 @@ def get_records():
return
_records
def
clear_records
():
global
_records
_records
=
{}
def
add_record
(
key
,
value
):
"""
"""
...
...
@@ -56,7 +61,7 @@ def _blackbox_cls(cls, module_name, register_format=None):
# eject un-serializable arguments
for
k
in
list
(
full_args
.
keys
()):
# The list is not complete and does not support nested cases.
if
not
isinstance
(
full_args
[
k
],
(
int
,
float
,
str
,
dict
,
list
)):
if
not
isinstance
(
full_args
[
k
],
(
int
,
float
,
str
,
dict
,
list
,
tuple
)):
if
not
(
register_format
==
'full'
and
k
==
'model'
):
# no warning if it is base model in trainer
warnings
.
warn
(
f
'
{
cls
}
has un-serializable arguments
{
k
}
whose value is
{
full_args
[
k
]
}
.
\
...
...
test/ut/retiarii/test_convert.py
0 → 100644
View file @
0f88b86b
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment