Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
f84647fd
Unverified
Commit
f84647fd
authored
May 07, 2022
by
Yuge Zhang
Committed by
GitHub
May 07, 2022
Browse files
Remove Generic from SerializableObject typings (#4844)
parent
bef663b8
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
53 additions
and
6 deletions
+53
-6
nni/common/serializer.py
nni/common/serializer.py
+6
-6
test/.gitignore
test/.gitignore
+1
-0
test/ut/sdk/imported/_test_serializer_main.py
test/ut/sdk/imported/_test_serializer_main.py
+22
-0
test/ut/sdk/test_serializer.py
test/ut/sdk/test_serializer.py
+24
-0
No files found.
nni/common/serializer.py
View file @
f84647fd
...
@@ -13,7 +13,7 @@ import sys
...
@@ -13,7 +13,7 @@ import sys
import
types
import
types
import
warnings
import
warnings
from
io
import
IOBase
from
io
import
IOBase
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Type
,
TypeVar
,
Union
,
cast
,
Generic
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Type
,
TypeVar
,
Union
,
cast
import
cloudpickle
# use cloudpickle as backend for unserializable types and instances
import
cloudpickle
# use cloudpickle as backend for unserializable types and instances
import
json_tricks
# use json_tricks as serializer backend
import
json_tricks
# use json_tricks as serializer backend
...
@@ -115,13 +115,13 @@ def is_wrapped_with_trace(cls_or_func: Any) -> bool:
...
@@ -115,13 +115,13 @@ def is_wrapped_with_trace(cls_or_func: Any) -> bool:
)
)
class
SerializableObject
(
Generic
[
T
],
Traceable
):
class
SerializableObject
(
Traceable
):
# should be (Generic[T], Traceable), but cloudpickle is unhappy with Generic.
"""
"""
Serializable object is a wrapper of existing python objects, that supports dump and load easily.
Serializable object is a wrapper of existing python objects, that supports dump and load easily.
Stores a symbol ``s`` and a dict of arguments ``args``, and the object can be restored with ``s(**args)``.
Stores a symbol ``s`` and a dict of arguments ``args``, and the object can be restored with ``s(**args)``.
"""
"""
def
__init__
(
self
,
symbol
:
T
,
args
:
List
[
Any
],
kwargs
:
Dict
[
str
,
Any
],
call_super
:
bool
=
False
):
def
__init__
(
self
,
symbol
:
T
ype
,
args
:
List
[
Any
],
kwargs
:
Dict
[
str
,
Any
],
call_super
:
bool
=
False
):
# use dict to avoid conflicts with user's getattr and setattr
# use dict to avoid conflicts with user's getattr and setattr
self
.
__dict__
[
'_nni_symbol'
]
=
symbol
self
.
__dict__
[
'_nni_symbol'
]
=
symbol
self
.
__dict__
[
'_nni_args'
]
=
args
self
.
__dict__
[
'_nni_args'
]
=
args
...
@@ -135,19 +135,19 @@ class SerializableObject(Generic[T], Traceable):
...
@@ -135,19 +135,19 @@ class SerializableObject(Generic[T], Traceable):
**
{
kw
:
_argument_processor
(
arg
)
for
kw
,
arg
in
kwargs
.
items
()}
**
{
kw
:
_argument_processor
(
arg
)
for
kw
,
arg
in
kwargs
.
items
()}
)
)
def
trace_copy
(
self
)
->
Union
[
T
,
'SerializableObject'
]
:
def
trace_copy
(
self
)
->
'SerializableObject'
:
return
SerializableObject
(
return
SerializableObject
(
self
.
trace_symbol
,
self
.
trace_symbol
,
[
copy
.
copy
(
arg
)
for
arg
in
self
.
trace_args
],
[
copy
.
copy
(
arg
)
for
arg
in
self
.
trace_args
],
{
k
:
copy
.
copy
(
v
)
for
k
,
v
in
self
.
trace_kwargs
.
items
()},
{
k
:
copy
.
copy
(
v
)
for
k
,
v
in
self
.
trace_kwargs
.
items
()},
)
)
def
get
(
self
)
->
T
:
def
get
(
self
)
->
Any
:
if
not
self
.
_get_nni_attr
(
'call_super'
):
if
not
self
.
_get_nni_attr
(
'call_super'
):
# Reinitialize
# Reinitialize
return
trace
(
self
.
trace_symbol
)(
*
self
.
trace_args
,
**
self
.
trace_kwargs
)
return
trace
(
self
.
trace_symbol
)(
*
self
.
trace_args
,
**
self
.
trace_kwargs
)
return
cast
(
T
,
self
)
return
self
@
property
@
property
def
trace_symbol
(
self
)
->
Any
:
def
trace_symbol
(
self
)
->
Any
:
...
...
test/.gitignore
View file @
f84647fd
...
@@ -3,6 +3,7 @@ __pycache__
...
@@ -3,6 +3,7 @@ __pycache__
tuner_search_space.json
tuner_search_space.json
tuner_result.txt
tuner_result.txt
assessor_result.txt
assessor_result.txt
serialize_result.txt
_generated_model.py
_generated_model.py
_generated_model_*.py
_generated_model_*.py
...
...
test/ut/sdk/imported/_test_serializer_main.py
0 → 100644
View file @
f84647fd
import
sys
import
torch.nn
as
nn
# sys.argv[1] == 0 -> dump
# sys.argv[1] == 1 -> load
import
nni
from
nni.retiarii
import
model_wrapper
@
model_wrapper
class
Net
(
nn
.
Module
):
something
=
1
import
cloudpickle
if
sys
.
argv
[
1
]
==
'0'
:
cloudpickle
.
dump
(
Net
,
open
(
'serialize_result.txt'
,
'wb'
))
# nni.dump(Net, fp=open('serialize_result.txt', 'w'))
else
:
obj
=
cloudpickle
.
load
(
open
(
'serialize_result.txt'
,
'rb'
))
# obj = nni.load(fp=open('serialize_result.txt'))
assert
obj
().
something
==
1
test/ut/sdk/test_serializer.py
View file @
f84647fd
import
math
import
math
import
os
import
pickle
import
pickle
import
subprocess
import
sys
import
sys
from
pathlib
import
Path
from
pathlib
import
Path
import
pytest
import
pytest
import
nni
import
nni
import
torch
import
torch
import
torch.nn
as
nn
from
torch.utils.data
import
DataLoader
from
torch.utils.data
import
DataLoader
from
torchvision
import
transforms
from
torchvision
import
transforms
from
torchvision.datasets
import
MNIST
from
torchvision.datasets
import
MNIST
...
@@ -375,3 +378,24 @@ def test_get():
...
@@ -375,3 +378,24 @@ def test_get():
obj2
=
obj1
.
trace_copy
()
obj2
=
obj1
.
trace_copy
()
obj2
.
trace_kwargs
[
'a'
]
=
-
1
obj2
.
trace_kwargs
[
'a'
]
=
-
1
assert
obj2
.
get
().
bar
()
==
0
assert
obj2
.
get
().
bar
()
==
0
def
test_model_wrapper_serialize
():
from
nni.retiarii
import
model_wrapper
@
model_wrapper
class
Model
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
):
super
().
__init__
()
self
.
in_channels
=
in_channels
model
=
Model
(
3
)
dumped
=
nni
.
dump
(
model
)
loaded
=
nni
.
load
(
dumped
)
assert
loaded
.
in_channels
==
3
def
test_model_wrapper_across_process
():
main_file
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'imported'
,
'_test_serializer_main.py'
)
subprocess
.
run
([
sys
.
executable
,
main_file
,
'0'
],
check
=
True
)
subprocess
.
run
([
sys
.
executable
,
main_file
,
'1'
],
check
=
True
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment