Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
85fd39a7
Unverified
Commit
85fd39a7
authored
Mar 05, 2021
by
Yuge Zhang
Committed by
GitHub
Mar 05, 2021
Browse files
[Retiarii] Serializer and experiment status fixes (#3421)
parent
d69d4ae9
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
38 additions
and
13 deletions
+38
-13
nni/retiarii/converter/graph_gen.py
nni/retiarii/converter/graph_gen.py
+2
-2
nni/retiarii/experiment/pytorch.py
nni/retiarii/experiment/pytorch.py
+11
-1
nni/retiarii/graph.py
nni/retiarii/graph.py
+2
-2
nni/retiarii/integration.py
nni/retiarii/integration.py
+3
-0
nni/retiarii/serializer.py
nni/retiarii/serializer.py
+8
-3
nni/retiarii/utils.py
nni/retiarii/utils.py
+9
-5
test/ut/retiarii/test_serializer.py
test/ut/retiarii/test_serializer.py
+3
-0
No files found.
nni/retiarii/converter/graph_gen.py
View file @
85fd39a7
...
...
@@ -6,7 +6,7 @@ from ..graph import Graph, Model, Node
from
..nn.pytorch
import
InputChoice
,
LayerChoice
,
Placeholder
from
..operation
import
Cell
,
Operation
from
..serializer
import
get_init_parameters_or_fail
from
..utils
import
get_
full_class
_name
from
..utils
import
get_
importable
_name
from
.op_types
import
MODULE_EXCEPT_LIST
,
OpTypeName
from
.utils
import
_convert_name
,
build_full_name
...
...
@@ -536,7 +536,7 @@ class GraphConverter:
def
_handle_layerchoice
(
self
,
module
):
choices
=
[]
for
cand
in
list
(
module
):
cand_type
=
'__torch__.'
+
get_
full_class
_name
(
cand
.
__class__
)
cand_type
=
'__torch__.'
+
get_
importable
_name
(
cand
.
__class__
)
choices
.
append
({
'type'
:
cand_type
,
'parameters'
:
get_init_parameters_or_fail
(
cand
)})
return
{
'candidates'
:
choices
,
...
...
nni/retiarii/experiment/pytorch.py
View file @
85fd39a7
import
logging
import
time
from
dataclasses
import
dataclass
from
pathlib
import
Path
from
subprocess
import
Popen
...
...
@@ -92,6 +93,8 @@ class RetiariiExperiment(Experiment):
self
.
_proc
:
Optional
[
Popen
]
=
None
self
.
_pipe
:
Optional
[
Pipe
]
=
None
self
.
_strategy_thread
:
Optional
[
Thread
]
=
None
def
_start_strategy
(
self
):
try
:
script_module
=
torch
.
jit
.
script
(
self
.
base_model
)
...
...
@@ -110,8 +113,11 @@ class RetiariiExperiment(Experiment):
self
.
applied_mutators
=
mutators
_logger
.
info
(
'Starting strategy...'
)
Thread
(
target
=
self
.
strategy
.
run
,
args
=
(
base_model_ir
,
self
.
applied_mutators
)).
start
()
# This is not intuitive and not friendly for debugging (setting breakpoints). Will refactor later.
self
.
_strategy_thread
=
Thread
(
target
=
self
.
strategy
.
run
,
args
=
(
base_model_ir
,
self
.
applied_mutators
))
self
.
_strategy_thread
.
start
()
_logger
.
info
(
'Strategy started!'
)
Thread
(
target
=
self
.
_strategy_monitor
).
start
()
def
start
(
self
,
port
:
int
=
8080
,
debug
:
bool
=
False
)
->
None
:
"""
...
...
@@ -131,6 +137,10 @@ class RetiariiExperiment(Experiment):
def
_create_dispatcher
(
self
):
return
self
.
_dispatcher
def
_strategy_monitor
(
self
):
self
.
_strategy_thread
.
join
()
self
.
_dispatcher
.
mark_experiment_as_ending
()
def
run
(
self
,
config
:
RetiariiExeConfig
=
None
,
port
:
int
=
8080
,
debug
:
bool
=
False
)
->
str
:
"""
Run the experiment.
...
...
nni/retiarii/graph.py
View file @
85fd39a7
...
...
@@ -9,7 +9,7 @@ from enum import Enum
from
typing
import
(
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
,
overload
)
from
.operation
import
Cell
,
Operation
,
_IOPseudoOperation
from
.utils
import
get_
full_class
_name
,
import_
,
uid
from
.utils
import
get_
importable
_name
,
import_
,
uid
__all__
=
[
'Model'
,
'ModelStatus'
,
'Graph'
,
'Node'
,
'Edge'
,
'IllegalGraphError'
,
'MetricData'
]
...
...
@@ -147,7 +147,7 @@ class Model:
def
_dump
(
self
)
->
Any
:
ret
=
{
name
:
graph
.
_dump
()
for
name
,
graph
in
self
.
graphs
.
items
()}
ret
[
'_evaluator'
]
=
{
'__type__'
:
get_
full_class
_name
(
self
.
evaluator
.
__class__
),
'__type__'
:
get_
importable
_name
(
self
.
evaluator
.
__class__
),
**
self
.
evaluator
.
_dump
()
}
return
ret
...
...
nni/retiarii/integration.py
View file @
85fd39a7
...
...
@@ -105,6 +105,9 @@ class RetiariiAdvisor(MsgDispatcherBase):
self
.
send_trial_callback
(
parameters
)
# pylint: disable=not-callable
return
self
.
parameters_count
def
mark_experiment_as_ending
(
self
):
send
(
CommandType
.
NoMoreTrialJobs
,
''
)
def
handle_request_trial_jobs
(
self
,
num_trials
):
_logger
.
info
(
'Request trial jobs: %s'
,
num_trials
)
if
self
.
request_trial_jobs_callback
is
not
None
:
...
...
nni/retiarii/serializer.py
View file @
85fd39a7
import
abc
import
functools
import
inspect
import
types
from
typing
import
Any
import
json_tricks
from
.utils
import
get_
full_class
_name
,
get_module_name
,
import_
from
.utils
import
get_
importable
_name
,
get_module_name
,
import_
def
get_init_parameters_or_fail
(
obj
,
silently
=
False
):
...
...
@@ -29,7 +30,7 @@ def _serialize_class_instance_encode(obj, primitives=False):
try
:
# FIXME: raise error
if
hasattr
(
obj
,
'__class__'
):
return
{
'__type__'
:
get_
full_class
_name
(
obj
.
__class__
),
'__type__'
:
get_
importable
_name
(
obj
.
__class__
),
'arguments'
:
get_init_parameters_or_fail
(
obj
)
}
except
ValueError
:
...
...
@@ -46,7 +47,11 @@ def _serialize_class_instance_decode(obj):
def
_type_encode
(
obj
,
primitives
=
False
):
assert
not
primitives
,
'Encoding with primitives is not supported.'
if
isinstance
(
obj
,
type
):
return
{
'__typename__'
:
get_full_class_name
(
obj
,
relocate_module
=
True
)}
return
{
'__typename__'
:
get_importable_name
(
obj
,
relocate_module
=
True
)}
if
isinstance
(
obj
,
(
types
.
FunctionType
,
types
.
BuiltinFunctionType
)):
# This is not reliable for cases like closure, `open`, or objects that is callable but not intended to be serialized.
# https://stackoverflow.com/questions/624926/how-do-i-detect-whether-a-python-variable-is-a-function
return
{
'__typename__'
:
get_importable_name
(
obj
,
relocate_module
=
True
)}
return
obj
...
...
nni/retiarii/utils.py
View file @
85fd39a7
import
inspect
import
warnings
from
collections
import
defaultdict
from
typing
import
Any
from
pathlib
import
Path
...
...
@@ -27,8 +28,8 @@ def uid(namespace: str = 'default') -> int:
return
_last_uid
[
namespace
]
def
get_module_name
(
cls
):
module_name
=
cls
.
__module__
def
get_module_name
(
cls
_or_func
):
module_name
=
cls
_or_func
.
__module__
if
module_name
==
'__main__'
:
# infer the module name with inspect
for
frm
in
inspect
.
stack
():
...
...
@@ -40,17 +41,20 @@ def get_module_name(cls):
f
'please launch the experiment under the directory where "
{
main_file_path
.
name
}
" is located.'
)
module_name
=
main_file_path
.
stem
break
if
module_name
==
'__main__'
:
warnings
.
warn
(
'Callstack exhausted but main module still not found. This will probably cause issues that the '
'function/class cannot be imported.'
)
# NOTE: this is hacky. As torchscript retrieves LSTM's source code to do something.
# to make LSTM's source code can be found, we should assign original LSTM's __module__ to
# the wrapped LSTM's __module__
# TODO: find out all the modules that have the same requirement as LSTM
if
f
'
{
cls
.
__module__
}
.
{
cls
.
__name__
}
'
==
'torch.nn.modules.rnn.LSTM'
:
module_name
=
cls
.
__module__
if
f
'
{
cls
_or_func
.
__module__
}
.
{
cls
_or_func
.
__name__
}
'
==
'torch.nn.modules.rnn.LSTM'
:
module_name
=
cls
_or_func
.
__module__
return
module_name
def
get_
full_class
_name
(
cls
,
relocate_module
=
False
):
def
get_
importable
_name
(
cls
,
relocate_module
=
False
):
module_name
=
get_module_name
(
cls
)
if
relocate_module
else
cls
.
__module__
return
module_name
+
'.'
+
cls
.
__name__
test/ut/retiarii/test_serializer.py
View file @
85fd39a7
import
json
import
math
from
pathlib
import
Path
import
re
import
sys
...
...
@@ -84,6 +85,8 @@ def test_type():
assert
json_dumps
(
torch
.
optim
.
Adam
)
==
'{"__typename__": "torch.optim.adam.Adam"}'
assert
json_loads
(
'{"__typename__": "torch.optim.adam.Adam"}'
)
==
torch
.
optim
.
Adam
assert
re
.
match
(
r
'{"__typename__": "(.*)test_serializer.Foo"}'
,
json_dumps
(
Foo
))
assert
json_dumps
(
math
.
floor
)
==
'{"__typename__": "math.floor"}'
assert
json_loads
(
'{"__typename__": "math.floor"}'
)
==
math
.
floor
if
__name__
==
'__main__'
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment