Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
85fd39a7
Unverified
Commit
85fd39a7
authored
Mar 05, 2021
by
Yuge Zhang
Committed by
GitHub
Mar 05, 2021
Browse files
[Retiarii] Serializer and experiment status fixes (#3421)
parent
d69d4ae9
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
38 additions
and
13 deletions
+38
-13
nni/retiarii/converter/graph_gen.py
nni/retiarii/converter/graph_gen.py
+2
-2
nni/retiarii/experiment/pytorch.py
nni/retiarii/experiment/pytorch.py
+11
-1
nni/retiarii/graph.py
nni/retiarii/graph.py
+2
-2
nni/retiarii/integration.py
nni/retiarii/integration.py
+3
-0
nni/retiarii/serializer.py
nni/retiarii/serializer.py
+8
-3
nni/retiarii/utils.py
nni/retiarii/utils.py
+9
-5
test/ut/retiarii/test_serializer.py
test/ut/retiarii/test_serializer.py
+3
-0
No files found.
nni/retiarii/converter/graph_gen.py
View file @
85fd39a7
...
@@ -6,7 +6,7 @@ from ..graph import Graph, Model, Node
...
@@ -6,7 +6,7 @@ from ..graph import Graph, Model, Node
from
..nn.pytorch
import
InputChoice
,
LayerChoice
,
Placeholder
from
..nn.pytorch
import
InputChoice
,
LayerChoice
,
Placeholder
from
..operation
import
Cell
,
Operation
from
..operation
import
Cell
,
Operation
from
..serializer
import
get_init_parameters_or_fail
from
..serializer
import
get_init_parameters_or_fail
from
..utils
import
get_
full_class
_name
from
..utils
import
get_
importable
_name
from
.op_types
import
MODULE_EXCEPT_LIST
,
OpTypeName
from
.op_types
import
MODULE_EXCEPT_LIST
,
OpTypeName
from
.utils
import
_convert_name
,
build_full_name
from
.utils
import
_convert_name
,
build_full_name
...
@@ -536,7 +536,7 @@ class GraphConverter:
...
@@ -536,7 +536,7 @@ class GraphConverter:
def
_handle_layerchoice
(
self
,
module
):
def
_handle_layerchoice
(
self
,
module
):
choices
=
[]
choices
=
[]
for
cand
in
list
(
module
):
for
cand
in
list
(
module
):
cand_type
=
'__torch__.'
+
get_
full_class
_name
(
cand
.
__class__
)
cand_type
=
'__torch__.'
+
get_
importable
_name
(
cand
.
__class__
)
choices
.
append
({
'type'
:
cand_type
,
'parameters'
:
get_init_parameters_or_fail
(
cand
)})
choices
.
append
({
'type'
:
cand_type
,
'parameters'
:
get_init_parameters_or_fail
(
cand
)})
return
{
return
{
'candidates'
:
choices
,
'candidates'
:
choices
,
...
...
nni/retiarii/experiment/pytorch.py
View file @
85fd39a7
import
logging
import
logging
import
time
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
pathlib
import
Path
from
pathlib
import
Path
from
subprocess
import
Popen
from
subprocess
import
Popen
...
@@ -92,6 +93,8 @@ class RetiariiExperiment(Experiment):
...
@@ -92,6 +93,8 @@ class RetiariiExperiment(Experiment):
self
.
_proc
:
Optional
[
Popen
]
=
None
self
.
_proc
:
Optional
[
Popen
]
=
None
self
.
_pipe
:
Optional
[
Pipe
]
=
None
self
.
_pipe
:
Optional
[
Pipe
]
=
None
self
.
_strategy_thread
:
Optional
[
Thread
]
=
None
def
_start_strategy
(
self
):
def
_start_strategy
(
self
):
try
:
try
:
script_module
=
torch
.
jit
.
script
(
self
.
base_model
)
script_module
=
torch
.
jit
.
script
(
self
.
base_model
)
...
@@ -110,8 +113,11 @@ class RetiariiExperiment(Experiment):
...
@@ -110,8 +113,11 @@ class RetiariiExperiment(Experiment):
self
.
applied_mutators
=
mutators
self
.
applied_mutators
=
mutators
_logger
.
info
(
'Starting strategy...'
)
_logger
.
info
(
'Starting strategy...'
)
Thread
(
target
=
self
.
strategy
.
run
,
args
=
(
base_model_ir
,
self
.
applied_mutators
)).
start
()
# This is not intuitive and not friendly for debugging (setting breakpoints). Will refactor later.
self
.
_strategy_thread
=
Thread
(
target
=
self
.
strategy
.
run
,
args
=
(
base_model_ir
,
self
.
applied_mutators
))
self
.
_strategy_thread
.
start
()
_logger
.
info
(
'Strategy started!'
)
_logger
.
info
(
'Strategy started!'
)
Thread
(
target
=
self
.
_strategy_monitor
).
start
()
def
start
(
self
,
port
:
int
=
8080
,
debug
:
bool
=
False
)
->
None
:
def
start
(
self
,
port
:
int
=
8080
,
debug
:
bool
=
False
)
->
None
:
"""
"""
...
@@ -131,6 +137,10 @@ class RetiariiExperiment(Experiment):
...
@@ -131,6 +137,10 @@ class RetiariiExperiment(Experiment):
def
_create_dispatcher
(
self
):
def
_create_dispatcher
(
self
):
return
self
.
_dispatcher
return
self
.
_dispatcher
def
_strategy_monitor
(
self
):
self
.
_strategy_thread
.
join
()
self
.
_dispatcher
.
mark_experiment_as_ending
()
def
run
(
self
,
config
:
RetiariiExeConfig
=
None
,
port
:
int
=
8080
,
debug
:
bool
=
False
)
->
str
:
def
run
(
self
,
config
:
RetiariiExeConfig
=
None
,
port
:
int
=
8080
,
debug
:
bool
=
False
)
->
str
:
"""
"""
Run the experiment.
Run the experiment.
...
...
nni/retiarii/graph.py
View file @
85fd39a7
...
@@ -9,7 +9,7 @@ from enum import Enum
...
@@ -9,7 +9,7 @@ from enum import Enum
from
typing
import
(
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
,
overload
)
from
typing
import
(
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
,
overload
)
from
.operation
import
Cell
,
Operation
,
_IOPseudoOperation
from
.operation
import
Cell
,
Operation
,
_IOPseudoOperation
from
.utils
import
get_
full_class
_name
,
import_
,
uid
from
.utils
import
get_
importable
_name
,
import_
,
uid
__all__
=
[
'Model'
,
'ModelStatus'
,
'Graph'
,
'Node'
,
'Edge'
,
'IllegalGraphError'
,
'MetricData'
]
__all__
=
[
'Model'
,
'ModelStatus'
,
'Graph'
,
'Node'
,
'Edge'
,
'IllegalGraphError'
,
'MetricData'
]
...
@@ -147,7 +147,7 @@ class Model:
...
@@ -147,7 +147,7 @@ class Model:
def
_dump
(
self
)
->
Any
:
def
_dump
(
self
)
->
Any
:
ret
=
{
name
:
graph
.
_dump
()
for
name
,
graph
in
self
.
graphs
.
items
()}
ret
=
{
name
:
graph
.
_dump
()
for
name
,
graph
in
self
.
graphs
.
items
()}
ret
[
'_evaluator'
]
=
{
ret
[
'_evaluator'
]
=
{
'__type__'
:
get_
full_class
_name
(
self
.
evaluator
.
__class__
),
'__type__'
:
get_
importable
_name
(
self
.
evaluator
.
__class__
),
**
self
.
evaluator
.
_dump
()
**
self
.
evaluator
.
_dump
()
}
}
return
ret
return
ret
...
...
nni/retiarii/integration.py
View file @
85fd39a7
...
@@ -105,6 +105,9 @@ class RetiariiAdvisor(MsgDispatcherBase):
...
@@ -105,6 +105,9 @@ class RetiariiAdvisor(MsgDispatcherBase):
self
.
send_trial_callback
(
parameters
)
# pylint: disable=not-callable
self
.
send_trial_callback
(
parameters
)
# pylint: disable=not-callable
return
self
.
parameters_count
return
self
.
parameters_count
def
mark_experiment_as_ending
(
self
):
send
(
CommandType
.
NoMoreTrialJobs
,
''
)
def
handle_request_trial_jobs
(
self
,
num_trials
):
def
handle_request_trial_jobs
(
self
,
num_trials
):
_logger
.
info
(
'Request trial jobs: %s'
,
num_trials
)
_logger
.
info
(
'Request trial jobs: %s'
,
num_trials
)
if
self
.
request_trial_jobs_callback
is
not
None
:
if
self
.
request_trial_jobs_callback
is
not
None
:
...
...
nni/retiarii/serializer.py
View file @
85fd39a7
import
abc
import
abc
import
functools
import
functools
import
inspect
import
inspect
import
types
from
typing
import
Any
from
typing
import
Any
import
json_tricks
import
json_tricks
from
.utils
import
get_
full_class
_name
,
get_module_name
,
import_
from
.utils
import
get_
importable
_name
,
get_module_name
,
import_
def
get_init_parameters_or_fail
(
obj
,
silently
=
False
):
def
get_init_parameters_or_fail
(
obj
,
silently
=
False
):
...
@@ -29,7 +30,7 @@ def _serialize_class_instance_encode(obj, primitives=False):
...
@@ -29,7 +30,7 @@ def _serialize_class_instance_encode(obj, primitives=False):
try
:
# FIXME: raise error
try
:
# FIXME: raise error
if
hasattr
(
obj
,
'__class__'
):
if
hasattr
(
obj
,
'__class__'
):
return
{
return
{
'__type__'
:
get_
full_class
_name
(
obj
.
__class__
),
'__type__'
:
get_
importable
_name
(
obj
.
__class__
),
'arguments'
:
get_init_parameters_or_fail
(
obj
)
'arguments'
:
get_init_parameters_or_fail
(
obj
)
}
}
except
ValueError
:
except
ValueError
:
...
@@ -46,7 +47,11 @@ def _serialize_class_instance_decode(obj):
...
@@ -46,7 +47,11 @@ def _serialize_class_instance_decode(obj):
def
_type_encode
(
obj
,
primitives
=
False
):
def
_type_encode
(
obj
,
primitives
=
False
):
assert
not
primitives
,
'Encoding with primitives is not supported.'
assert
not
primitives
,
'Encoding with primitives is not supported.'
if
isinstance
(
obj
,
type
):
if
isinstance
(
obj
,
type
):
return
{
'__typename__'
:
get_full_class_name
(
obj
,
relocate_module
=
True
)}
return
{
'__typename__'
:
get_importable_name
(
obj
,
relocate_module
=
True
)}
if
isinstance
(
obj
,
(
types
.
FunctionType
,
types
.
BuiltinFunctionType
)):
# This is not reliable for cases like closure, `open`, or objects that is callable but not intended to be serialized.
# https://stackoverflow.com/questions/624926/how-do-i-detect-whether-a-python-variable-is-a-function
return
{
'__typename__'
:
get_importable_name
(
obj
,
relocate_module
=
True
)}
return
obj
return
obj
...
...
nni/retiarii/utils.py
View file @
85fd39a7
import
inspect
import
inspect
import
warnings
from
collections
import
defaultdict
from
collections
import
defaultdict
from
typing
import
Any
from
typing
import
Any
from
pathlib
import
Path
from
pathlib
import
Path
...
@@ -27,8 +28,8 @@ def uid(namespace: str = 'default') -> int:
...
@@ -27,8 +28,8 @@ def uid(namespace: str = 'default') -> int:
return
_last_uid
[
namespace
]
return
_last_uid
[
namespace
]
def
get_module_name
(
cls
):
def
get_module_name
(
cls
_or_func
):
module_name
=
cls
.
__module__
module_name
=
cls
_or_func
.
__module__
if
module_name
==
'__main__'
:
if
module_name
==
'__main__'
:
# infer the module name with inspect
# infer the module name with inspect
for
frm
in
inspect
.
stack
():
for
frm
in
inspect
.
stack
():
...
@@ -40,17 +41,20 @@ def get_module_name(cls):
...
@@ -40,17 +41,20 @@ def get_module_name(cls):
f
'please launch the experiment under the directory where "
{
main_file_path
.
name
}
" is located.'
)
f
'please launch the experiment under the directory where "
{
main_file_path
.
name
}
" is located.'
)
module_name
=
main_file_path
.
stem
module_name
=
main_file_path
.
stem
break
break
if
module_name
==
'__main__'
:
warnings
.
warn
(
'Callstack exhausted but main module still not found. This will probably cause issues that the '
'function/class cannot be imported.'
)
# NOTE: this is hacky. As torchscript retrieves LSTM's source code to do something.
# NOTE: this is hacky. As torchscript retrieves LSTM's source code to do something.
# to make LSTM's source code can be found, we should assign original LSTM's __module__ to
# to make LSTM's source code can be found, we should assign original LSTM's __module__ to
# the wrapped LSTM's __module__
# the wrapped LSTM's __module__
# TODO: find out all the modules that have the same requirement as LSTM
# TODO: find out all the modules that have the same requirement as LSTM
if
f
'
{
cls
.
__module__
}
.
{
cls
.
__name__
}
'
==
'torch.nn.modules.rnn.LSTM'
:
if
f
'
{
cls
_or_func
.
__module__
}
.
{
cls
_or_func
.
__name__
}
'
==
'torch.nn.modules.rnn.LSTM'
:
module_name
=
cls
.
__module__
module_name
=
cls
_or_func
.
__module__
return
module_name
return
module_name
def
get_
full_class
_name
(
cls
,
relocate_module
=
False
):
def
get_
importable
_name
(
cls
,
relocate_module
=
False
):
module_name
=
get_module_name
(
cls
)
if
relocate_module
else
cls
.
__module__
module_name
=
get_module_name
(
cls
)
if
relocate_module
else
cls
.
__module__
return
module_name
+
'.'
+
cls
.
__name__
return
module_name
+
'.'
+
cls
.
__name__
test/ut/retiarii/test_serializer.py
View file @
85fd39a7
import
json
import
json
import
math
from
pathlib
import
Path
from
pathlib
import
Path
import
re
import
re
import
sys
import
sys
...
@@ -84,6 +85,8 @@ def test_type():
...
@@ -84,6 +85,8 @@ def test_type():
assert
json_dumps
(
torch
.
optim
.
Adam
)
==
'{"__typename__": "torch.optim.adam.Adam"}'
assert
json_dumps
(
torch
.
optim
.
Adam
)
==
'{"__typename__": "torch.optim.adam.Adam"}'
assert
json_loads
(
'{"__typename__": "torch.optim.adam.Adam"}'
)
==
torch
.
optim
.
Adam
assert
json_loads
(
'{"__typename__": "torch.optim.adam.Adam"}'
)
==
torch
.
optim
.
Adam
assert
re
.
match
(
r
'{"__typename__": "(.*)test_serializer.Foo"}'
,
json_dumps
(
Foo
))
assert
re
.
match
(
r
'{"__typename__": "(.*)test_serializer.Foo"}'
,
json_dumps
(
Foo
))
assert
json_dumps
(
math
.
floor
)
==
'{"__typename__": "math.floor"}'
assert
json_loads
(
'{"__typename__": "math.floor"}'
)
==
math
.
floor
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment