Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
85fd39a7
Unverified
Commit
85fd39a7
authored
Mar 05, 2021
by
Yuge Zhang
Committed by
GitHub
Mar 05, 2021
Browse files
[Retiarii] Serializer and experiment status fixes (#3421)
parent
d69d4ae9
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
38 additions
and
13 deletions
+38
-13
nni/retiarii/converter/graph_gen.py
nni/retiarii/converter/graph_gen.py
+2
-2
nni/retiarii/experiment/pytorch.py
nni/retiarii/experiment/pytorch.py
+11
-1
nni/retiarii/graph.py
nni/retiarii/graph.py
+2
-2
nni/retiarii/integration.py
nni/retiarii/integration.py
+3
-0
nni/retiarii/serializer.py
nni/retiarii/serializer.py
+8
-3
nni/retiarii/utils.py
nni/retiarii/utils.py
+9
-5
test/ut/retiarii/test_serializer.py
test/ut/retiarii/test_serializer.py
+3
-0
No files found.
nni/retiarii/converter/graph_gen.py
View file @
85fd39a7
...
...
@@ -6,7 +6,7 @@ from ..graph import Graph, Model, Node
from
..nn.pytorch
import
InputChoice
,
LayerChoice
,
Placeholder
from
..operation
import
Cell
,
Operation
from
..serializer
import
get_init_parameters_or_fail
from
..utils
import
get_
full_class
_name
from
..utils
import
get_
importable
_name
from
.op_types
import
MODULE_EXCEPT_LIST
,
OpTypeName
from
.utils
import
_convert_name
,
build_full_name
...
...
@@ -536,7 +536,7 @@ class GraphConverter:
def
_handle_layerchoice
(
self
,
module
):
choices
=
[]
for
cand
in
list
(
module
):
cand_type
=
'__torch__.'
+
get_
full_class
_name
(
cand
.
__class__
)
cand_type
=
'__torch__.'
+
get_
importable
_name
(
cand
.
__class__
)
choices
.
append
({
'type'
:
cand_type
,
'parameters'
:
get_init_parameters_or_fail
(
cand
)})
return
{
'candidates'
:
choices
,
...
...
nni/retiarii/experiment/pytorch.py
View file @
85fd39a7
import
logging
import
time
from
dataclasses
import
dataclass
from
pathlib
import
Path
from
subprocess
import
Popen
...
...
@@ -92,6 +93,8 @@ class RetiariiExperiment(Experiment):
self
.
_proc
:
Optional
[
Popen
]
=
None
self
.
_pipe
:
Optional
[
Pipe
]
=
None
self
.
_strategy_thread
:
Optional
[
Thread
]
=
None
def
_start_strategy
(
self
):
try
:
script_module
=
torch
.
jit
.
script
(
self
.
base_model
)
...
...
@@ -110,8 +113,11 @@ class RetiariiExperiment(Experiment):
self
.
applied_mutators
=
mutators
_logger
.
info
(
'Starting strategy...'
)
Thread
(
target
=
self
.
strategy
.
run
,
args
=
(
base_model_ir
,
self
.
applied_mutators
)).
start
()
# This is not intuitive and not friendly for debugging (setting breakpoints). Will refactor later.
self
.
_strategy_thread
=
Thread
(
target
=
self
.
strategy
.
run
,
args
=
(
base_model_ir
,
self
.
applied_mutators
))
self
.
_strategy_thread
.
start
()
_logger
.
info
(
'Strategy started!'
)
Thread
(
target
=
self
.
_strategy_monitor
).
start
()
def
start
(
self
,
port
:
int
=
8080
,
debug
:
bool
=
False
)
->
None
:
"""
...
...
@@ -131,6 +137,10 @@ class RetiariiExperiment(Experiment):
def
_create_dispatcher
(
self
):
return
self
.
_dispatcher
def
_strategy_monitor
(
self
):
self
.
_strategy_thread
.
join
()
self
.
_dispatcher
.
mark_experiment_as_ending
()
def
run
(
self
,
config
:
RetiariiExeConfig
=
None
,
port
:
int
=
8080
,
debug
:
bool
=
False
)
->
str
:
"""
Run the experiment.
...
...
nni/retiarii/graph.py
View file @
85fd39a7
...
...
@@ -9,7 +9,7 @@ from enum import Enum
from
typing
import
(
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
,
overload
)
from
.operation
import
Cell
,
Operation
,
_IOPseudoOperation
from
.utils
import
get_
full_class
_name
,
import_
,
uid
from
.utils
import
get_
importable
_name
,
import_
,
uid
__all__
=
[
'Model'
,
'ModelStatus'
,
'Graph'
,
'Node'
,
'Edge'
,
'IllegalGraphError'
,
'MetricData'
]
...
...
@@ -147,7 +147,7 @@ class Model:
def
_dump
(
self
)
->
Any
:
ret
=
{
name
:
graph
.
_dump
()
for
name
,
graph
in
self
.
graphs
.
items
()}
ret
[
'_evaluator'
]
=
{
'__type__'
:
get_
full_class
_name
(
self
.
evaluator
.
__class__
),
'__type__'
:
get_
importable
_name
(
self
.
evaluator
.
__class__
),
**
self
.
evaluator
.
_dump
()
}
return
ret
...
...
nni/retiarii/integration.py
View file @
85fd39a7
...
...
@@ -105,6 +105,9 @@ class RetiariiAdvisor(MsgDispatcherBase):
self
.
send_trial_callback
(
parameters
)
# pylint: disable=not-callable
return
self
.
parameters_count
def
mark_experiment_as_ending
(
self
):
send
(
CommandType
.
NoMoreTrialJobs
,
''
)
def
handle_request_trial_jobs
(
self
,
num_trials
):
_logger
.
info
(
'Request trial jobs: %s'
,
num_trials
)
if
self
.
request_trial_jobs_callback
is
not
None
:
...
...
nni/retiarii/serializer.py
View file @
85fd39a7
import
abc
import
functools
import
inspect
import
types
from
typing
import
Any
import
json_tricks
from
.utils
import
get_
full_class
_name
,
get_module_name
,
import_
from
.utils
import
get_
importable
_name
,
get_module_name
,
import_
def
get_init_parameters_or_fail
(
obj
,
silently
=
False
):
...
...
@@ -29,7 +30,7 @@ def _serialize_class_instance_encode(obj, primitives=False):
try
:
# FIXME: raise error
if
hasattr
(
obj
,
'__class__'
):
return
{
'__type__'
:
get_
full_class
_name
(
obj
.
__class__
),
'__type__'
:
get_
importable
_name
(
obj
.
__class__
),
'arguments'
:
get_init_parameters_or_fail
(
obj
)
}
except
ValueError
:
...
...
@@ -46,7 +47,11 @@ def _serialize_class_instance_decode(obj):
def
_type_encode
(
obj
,
primitives
=
False
):
assert
not
primitives
,
'Encoding with primitives is not supported.'
if
isinstance
(
obj
,
type
):
return
{
'__typename__'
:
get_full_class_name
(
obj
,
relocate_module
=
True
)}
return
{
'__typename__'
:
get_importable_name
(
obj
,
relocate_module
=
True
)}
if
isinstance
(
obj
,
(
types
.
FunctionType
,
types
.
BuiltinFunctionType
)):
# This is not reliable for cases like closure, `open`, or objects that is callable but not intended to be serialized.
# https://stackoverflow.com/questions/624926/how-do-i-detect-whether-a-python-variable-is-a-function
return
{
'__typename__'
:
get_importable_name
(
obj
,
relocate_module
=
True
)}
return
obj
...
...
nni/retiarii/utils.py
View file @
85fd39a7
import
inspect
import
warnings
from
collections
import
defaultdict
from
typing
import
Any
from
pathlib
import
Path
...
...
@@ -27,8 +28,8 @@ def uid(namespace: str = 'default') -> int:
return
_last_uid
[
namespace
]
def
get_module_name
(
cls
):
module_name
=
cls
.
__module__
def
get_module_name
(
cls
_or_func
):
module_name
=
cls
_or_func
.
__module__
if
module_name
==
'__main__'
:
# infer the module name with inspect
for
frm
in
inspect
.
stack
():
...
...
@@ -40,17 +41,20 @@ def get_module_name(cls):
f
'please launch the experiment under the directory where "
{
main_file_path
.
name
}
" is located.'
)
module_name
=
main_file_path
.
stem
break
if
module_name
==
'__main__'
:
warnings
.
warn
(
'Callstack exhausted but main module still not found. This will probably cause issues that the '
'function/class cannot be imported.'
)
# NOTE: this is hacky. As torchscript retrieves LSTM's source code to do something.
# to make LSTM's source code can be found, we should assign original LSTM's __module__ to
# the wrapped LSTM's __module__
# TODO: find out all the modules that have the same requirement as LSTM
if
f
'
{
cls
.
__module__
}
.
{
cls
.
__name__
}
'
==
'torch.nn.modules.rnn.LSTM'
:
module_name
=
cls
.
__module__
if
f
'
{
cls
_or_func
.
__module__
}
.
{
cls
_or_func
.
__name__
}
'
==
'torch.nn.modules.rnn.LSTM'
:
module_name
=
cls
_or_func
.
__module__
return
module_name
def
get_
full_class
_name
(
cls
,
relocate_module
=
False
):
def
get_
importable
_name
(
cls
,
relocate_module
=
False
):
module_name
=
get_module_name
(
cls
)
if
relocate_module
else
cls
.
__module__
return
module_name
+
'.'
+
cls
.
__name__
test/ut/retiarii/test_serializer.py
View file @
85fd39a7
import
json
import
math
from
pathlib
import
Path
import
re
import
sys
...
...
@@ -84,6 +85,8 @@ def test_type():
assert
json_dumps
(
torch
.
optim
.
Adam
)
==
'{"__typename__": "torch.optim.adam.Adam"}'
assert
json_loads
(
'{"__typename__": "torch.optim.adam.Adam"}'
)
==
torch
.
optim
.
Adam
assert
re
.
match
(
r
'{"__typename__": "(.*)test_serializer.Foo"}'
,
json_dumps
(
Foo
))
assert
json_dumps
(
math
.
floor
)
==
'{"__typename__": "math.floor"}'
assert
json_loads
(
'{"__typename__": "math.floor"}'
)
==
math
.
floor
if
__name__
==
'__main__'
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment