Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
357ec6ef
Unverified
Commit
357ec6ef
authored
Dec 10, 2021
by
QuanluZhang
Committed by
GitHub
Dec 10, 2021
Browse files
[retiarii] support visualize model space with the hpo chart on webui (#4304)
parent
844f670b
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
81 additions
and
27 deletions
+81
-27
nni/retiarii/execution/base.py
nni/retiarii/execution/base.py
+18
-5
nni/retiarii/execution/benchmark.py
nni/retiarii/execution/benchmark.py
+1
-1
nni/retiarii/execution/cgo_engine.py
nni/retiarii/execution/cgo_engine.py
+1
-1
nni/retiarii/execution/python.py
nni/retiarii/execution/python.py
+5
-12
nni/retiarii/execution/utils.py
nni/retiarii/execution/utils.py
+25
-0
nni/retiarii/experiment/pytorch.py
nni/retiarii/experiment/pytorch.py
+4
-1
nni/retiarii/integration_api.py
nni/retiarii/integration_api.py
+0
-1
nni/retiarii/strategy/local_debug_strategy.py
nni/retiarii/strategy/local_debug_strategy.py
+3
-1
nni/retiarii/strategy/utils.py
nni/retiarii/strategy/utils.py
+10
-0
test/ut/retiarii/test_highlevel_apis.py
test/ut/retiarii/test_highlevel_apis.py
+1
-1
ts/nni_manager/core/nnimanager.ts
ts/nni_manager/core/nnimanager.ts
+2
-1
ts/webui/src/static/interface.ts
ts/webui/src/static/interface.ts
+6
-1
ts/webui/src/static/model/trial.ts
ts/webui/src/static/model/trial.ts
+5
-2
No files found.
nni/retiarii/execution/base.py
View file @
357ec6ef
...
@@ -8,27 +8,39 @@ import string
...
@@ -8,27 +8,39 @@ import string
from
typing
import
Any
,
Dict
,
Iterable
,
List
from
typing
import
Any
,
Dict
,
Iterable
,
List
from
.interface
import
AbstractExecutionEngine
,
AbstractGraphListener
from
.interface
import
AbstractExecutionEngine
,
AbstractGraphListener
from
.utils
import
get_mutation_summary
from
..
import
codegen
,
utils
from
..
import
codegen
,
utils
from
..graph
import
Model
,
ModelStatus
,
MetricData
,
Evaluator
from
..graph
import
Model
,
ModelStatus
,
MetricData
,
Evaluator
from
..integration_api
import
send_trial
,
receive_trial_parameters
,
get_advisor
from
..integration_api
import
send_trial
,
receive_trial_parameters
,
get_advisor
_logger
=
logging
.
getLogger
(
__name__
)
_logger
=
logging
.
getLogger
(
__name__
)
class
BaseGraphData
:
class
BaseGraphData
:
def
__init__
(
self
,
model_script
:
str
,
evaluator
:
Evaluator
)
->
None
:
"""
Attributes
----------
model_script
code of an instantiated PyTorch model
evaluator
training approach for model_script
mutation_summary
a dict of all the choices during mutations in the HPO search space format
"""
def
__init__
(
self
,
model_script
:
str
,
evaluator
:
Evaluator
,
mutation_summary
:
dict
)
->
None
:
self
.
model_script
=
model_script
self
.
model_script
=
model_script
self
.
evaluator
=
evaluator
self
.
evaluator
=
evaluator
self
.
mutation_summary
=
mutation_summary
def
dump
(
self
)
->
dict
:
def
dump
(
self
)
->
dict
:
return
{
return
{
'model_script'
:
self
.
model_script
,
'model_script'
:
self
.
model_script
,
'evaluator'
:
self
.
evaluator
'evaluator'
:
self
.
evaluator
,
'mutation_summary'
:
self
.
mutation_summary
}
}
@
staticmethod
@
staticmethod
def
load
(
data
)
->
'BaseGraphData'
:
def
load
(
data
)
->
'BaseGraphData'
:
return
BaseGraphData
(
data
[
'model_script'
],
data
[
'evaluator'
])
return
BaseGraphData
(
data
[
'model_script'
],
data
[
'evaluator'
]
,
data
[
'mutation_summary'
]
)
class
BaseExecutionEngine
(
AbstractExecutionEngine
):
class
BaseExecutionEngine
(
AbstractExecutionEngine
):
...
@@ -111,7 +123,8 @@ class BaseExecutionEngine(AbstractExecutionEngine):
...
@@ -111,7 +123,8 @@ class BaseExecutionEngine(AbstractExecutionEngine):
@
classmethod
@
classmethod
def
pack_model_data
(
cls
,
model
:
Model
)
->
Any
:
def
pack_model_data
(
cls
,
model
:
Model
)
->
Any
:
return
BaseGraphData
(
codegen
.
model_to_pytorch_script
(
model
),
model
.
evaluator
)
mutation_summary
=
get_mutation_summary
(
model
)
return
BaseGraphData
(
codegen
.
model_to_pytorch_script
(
model
),
model
.
evaluator
,
mutation_summary
)
@
classmethod
@
classmethod
def
trial_execute_graph
(
cls
)
->
None
:
def
trial_execute_graph
(
cls
)
->
None
:
...
...
nni/retiarii/execution/benchmark.py
View file @
357ec6ef
...
@@ -5,7 +5,7 @@ from typing import Dict, Any, List, Optional, Union, Tuple, Callable, Iterable
...
@@ -5,7 +5,7 @@ from typing import Dict, Any, List, Optional, Union, Tuple, Callable, Iterable
from
..graph
import
Model
from
..graph
import
Model
from
..integration_api
import
receive_trial_parameters
from
..integration_api
import
receive_trial_parameters
from
.base
import
BaseExecutionEngine
from
.base
import
BaseExecutionEngine
from
.
python
import
get_mutation_dict
from
.
utils
import
get_mutation_dict
class
BenchmarkGraphData
:
class
BenchmarkGraphData
:
...
...
nni/retiarii/execution/cgo_engine.py
View file @
357ec6ef
...
@@ -156,7 +156,7 @@ class CGOExecutionEngine(AbstractExecutionEngine):
...
@@ -156,7 +156,7 @@ class CGOExecutionEngine(AbstractExecutionEngine):
phy_models_and_placements
=
self
.
_assemble
(
logical
)
phy_models_and_placements
=
self
.
_assemble
(
logical
)
for
model
,
placement
,
grouped_models
in
phy_models_and_placements
:
for
model
,
placement
,
grouped_models
in
phy_models_and_placements
:
data
=
BaseGraphData
(
codegen
.
model_to_pytorch_script
(
model
,
placement
=
placement
),
model
.
evaluator
)
data
=
BaseGraphData
(
codegen
.
model_to_pytorch_script
(
model
,
placement
=
placement
),
model
.
evaluator
,
{}
)
placement_constraint
=
self
.
_extract_placement_constaint
(
placement
)
placement_constraint
=
self
.
_extract_placement_constaint
(
placement
)
trial_id
=
send_trial
(
data
.
dump
(),
placement_constraint
=
placement_constraint
)
trial_id
=
send_trial
(
data
.
dump
(),
placement_constraint
=
placement_constraint
)
# unique non-cpu devices used by the trial
# unique non-cpu devices used by the trial
...
...
nni/retiarii/execution/python.py
View file @
357ec6ef
from
typing
import
Dict
,
Any
,
List
from
typing
import
Dict
,
Any
from
..graph
import
Evaluator
,
Model
from
..graph
import
Evaluator
,
Model
from
..integration_api
import
receive_trial_parameters
from
..integration_api
import
receive_trial_parameters
from
..utils
import
ContextStack
,
import_
,
get_importable_name
from
..utils
import
ContextStack
,
import_
,
get_importable_name
from
.base
import
BaseExecutionEngine
from
.base
import
BaseExecutionEngine
from
.utils
import
get_mutation_dict
,
mutation_dict_to_summary
class
PythonGraphData
:
class
PythonGraphData
:
...
@@ -13,13 +14,15 @@ class PythonGraphData:
...
@@ -13,13 +14,15 @@ class PythonGraphData:
self
.
init_parameters
=
init_parameters
self
.
init_parameters
=
init_parameters
self
.
mutation
=
mutation
self
.
mutation
=
mutation
self
.
evaluator
=
evaluator
self
.
evaluator
=
evaluator
self
.
mutation_summary
=
mutation_dict_to_summary
(
mutation
)
def
dump
(
self
)
->
dict
:
def
dump
(
self
)
->
dict
:
return
{
return
{
'class_name'
:
self
.
class_name
,
'class_name'
:
self
.
class_name
,
'init_parameters'
:
self
.
init_parameters
,
'init_parameters'
:
self
.
init_parameters
,
'mutation'
:
self
.
mutation
,
'mutation'
:
self
.
mutation
,
'evaluator'
:
self
.
evaluator
'evaluator'
:
self
.
evaluator
,
'mutation_summary'
:
self
.
mutation_summary
}
}
@
staticmethod
@
staticmethod
...
@@ -55,13 +58,3 @@ class PurePythonExecutionEngine(BaseExecutionEngine):
...
@@ -55,13 +58,3 @@ class PurePythonExecutionEngine(BaseExecutionEngine):
with
ContextStack
(
'fixed'
,
graph_data
.
mutation
):
with
ContextStack
(
'fixed'
,
graph_data
.
mutation
):
graph_data
.
evaluator
.
_execute
(
_model
)
graph_data
.
evaluator
.
_execute
(
_model
)
def
_unpack_if_only_one
(
ele
:
List
[
Any
]):
if
len
(
ele
)
==
1
:
return
ele
[
0
]
return
ele
def
get_mutation_dict
(
model
:
Model
):
return
{
mut
.
mutator
.
label
:
_unpack_if_only_one
(
mut
.
samples
)
for
mut
in
model
.
history
}
nni/retiarii/execution/utils.py
0 → 100644
View file @
357ec6ef
from
typing
import
Any
,
List
from
..graph
import
Model
def
_unpack_if_only_one
(
ele
:
List
[
Any
]):
if
len
(
ele
)
==
1
:
return
ele
[
0
]
return
ele
def
get_mutation_dict
(
model
:
Model
):
return
{
mut
.
mutator
.
label
:
_unpack_if_only_one
(
mut
.
samples
)
for
mut
in
model
.
history
}
def
mutation_dict_to_summary
(
mutation
:
dict
)
->
dict
:
mutation_summary
=
{}
for
label
,
samples
in
mutation
.
items
():
# FIXME: this check might be wrong
if
not
isinstance
(
samples
,
list
):
mutation_summary
[
label
]
=
samples
else
:
for
i
,
sample
in
enumerate
(
samples
):
mutation_summary
[
f
'
{
label
}
_
{
i
}
'
]
=
sample
return
mutation_summary
def
get_mutation_summary
(
model
:
Model
)
->
dict
:
mutation
=
get_mutation_dict
(
model
)
return
mutation_dict_to_summary
(
mutation
)
nni/retiarii/experiment/pytorch.py
View file @
357ec6ef
...
@@ -28,13 +28,14 @@ from ..codegen import model_to_pytorch_script
...
@@ -28,13 +28,14 @@ from ..codegen import model_to_pytorch_script
from
..converter
import
convert_to_graph
from
..converter
import
convert_to_graph
from
..converter.graph_gen
import
GraphConverterWithShape
from
..converter.graph_gen
import
GraphConverterWithShape
from
..execution
import
list_models
,
set_execution_engine
from
..execution
import
list_models
,
set_execution_engine
from
..execution.
python
import
get_mutation_dict
from
..execution.
utils
import
get_mutation_dict
from
..graph
import
Evaluator
from
..graph
import
Evaluator
from
..integration
import
RetiariiAdvisor
from
..integration
import
RetiariiAdvisor
from
..mutator
import
Mutator
from
..mutator
import
Mutator
from
..nn.pytorch.mutator
import
extract_mutation_from_pt_module
,
process_inline_mutation
from
..nn.pytorch.mutator
import
extract_mutation_from_pt_module
,
process_inline_mutation
from
..oneshot.interface
import
BaseOneShotTrainer
from
..oneshot.interface
import
BaseOneShotTrainer
from
..strategy
import
BaseStrategy
from
..strategy
import
BaseStrategy
from
..strategy.utils
import
dry_run_for_formatted_search_space
_logger
=
logging
.
getLogger
(
__name__
)
_logger
=
logging
.
getLogger
(
__name__
)
...
@@ -193,6 +194,8 @@ class RetiariiExperiment(Experiment):
...
@@ -193,6 +194,8 @@ class RetiariiExperiment(Experiment):
)
)
_logger
.
info
(
'Start strategy...'
)
_logger
.
info
(
'Start strategy...'
)
search_space
=
dry_run_for_formatted_search_space
(
base_model_ir
,
self
.
applied_mutators
)
self
.
update_search_space
(
search_space
)
self
.
strategy
.
run
(
base_model_ir
,
self
.
applied_mutators
)
self
.
strategy
.
run
(
base_model_ir
,
self
.
applied_mutators
)
_logger
.
info
(
'Strategy exit'
)
_logger
.
info
(
'Strategy exit'
)
# TODO: find out a proper way to show no more trial message on WebUI
# TODO: find out a proper way to show no more trial message on WebUI
...
...
nni/retiarii/integration_api.py
View file @
357ec6ef
...
@@ -31,7 +31,6 @@ def send_trial(parameters: dict, placement_constraint=None) -> int:
...
@@ -31,7 +31,6 @@ def send_trial(parameters: dict, placement_constraint=None) -> int:
"""
"""
return
get_advisor
().
send_trial
(
parameters
,
placement_constraint
)
return
get_advisor
().
send_trial
(
parameters
,
placement_constraint
)
def
receive_trial_parameters
()
->
dict
:
def
receive_trial_parameters
()
->
dict
:
"""
"""
Received a new trial. Executed on trial end.
Received a new trial. Executed on trial end.
...
...
nni/retiarii/strategy/local_debug_strategy.py
View file @
357ec6ef
...
@@ -8,6 +8,7 @@ import string
...
@@ -8,6 +8,7 @@ import string
from
..
import
Sampler
,
codegen
,
utils
from
..
import
Sampler
,
codegen
,
utils
from
..execution.base
import
BaseGraphData
from
..execution.base
import
BaseGraphData
from
..execution.utils
import
get_mutation_summary
from
.base
import
BaseStrategy
from
.base
import
BaseStrategy
_logger
=
logging
.
getLogger
(
__name__
)
_logger
=
logging
.
getLogger
(
__name__
)
...
@@ -22,7 +23,8 @@ class _LocalDebugStrategy(BaseStrategy):
...
@@ -22,7 +23,8 @@ class _LocalDebugStrategy(BaseStrategy):
"""
"""
def
run_one_model
(
self
,
model
):
def
run_one_model
(
self
,
model
):
graph_data
=
BaseGraphData
(
codegen
.
model_to_pytorch_script
(
model
),
model
.
evaluator
)
mutation_summary
=
get_mutation_summary
(
model
)
graph_data
=
BaseGraphData
(
codegen
.
model_to_pytorch_script
(
model
),
model
.
evaluator
,
mutation_summary
)
random_str
=
''
.
join
(
random
.
choice
(
string
.
ascii_uppercase
+
string
.
digits
)
for
_
in
range
(
6
))
random_str
=
''
.
join
(
random
.
choice
(
string
.
ascii_uppercase
+
string
.
digits
)
for
_
in
range
(
6
))
file_name
=
f
'_generated_model/
{
random_str
}
.py'
file_name
=
f
'_generated_model/
{
random_str
}
.py'
os
.
makedirs
(
os
.
path
.
dirname
(
file_name
),
exist_ok
=
True
)
os
.
makedirs
(
os
.
path
.
dirname
(
file_name
),
exist_ok
=
True
)
...
...
nni/retiarii/strategy/utils.py
View file @
357ec6ef
...
@@ -27,6 +27,16 @@ def dry_run_for_search_space(model: Model, mutators: List[Mutator]) -> Dict[Any,
...
@@ -27,6 +27,16 @@ def dry_run_for_search_space(model: Model, mutators: List[Mutator]) -> Dict[Any,
search_space
[(
mutator
,
i
)]
=
candidates
search_space
[(
mutator
,
i
)]
=
candidates
return
search_space
return
search_space
def
dry_run_for_formatted_search_space
(
model
:
Model
,
mutators
:
List
[
Mutator
])
->
Dict
[
Any
,
Dict
[
Any
,
Any
]]:
search_space
=
collections
.
OrderedDict
()
for
mutator
in
mutators
:
recorded_candidates
,
model
=
mutator
.
dry_run
(
model
)
if
len
(
recorded_candidates
)
==
1
:
search_space
[
mutator
.
label
]
=
{
'_type'
:
'choice'
,
'_value'
:
recorded_candidates
[
0
]}
else
:
for
i
,
candidate
in
enumerate
(
recorded_candidates
):
search_space
[
f
'
{
mutator
.
label
}
_
{
i
}
'
]
=
{
'_type'
:
'choice'
,
'_value'
:
candidate
}
return
search_space
def
get_targeted_model
(
base_model
:
Model
,
mutators
:
List
[
Mutator
],
sample
:
dict
)
->
Model
:
def
get_targeted_model
(
base_model
:
Model
,
mutators
:
List
[
Mutator
],
sample
:
dict
)
->
Model
:
sampler
=
_FixedSampler
(
sample
)
sampler
=
_FixedSampler
(
sample
)
...
...
test/ut/retiarii/test_highlevel_apis.py
View file @
357ec6ef
...
@@ -8,7 +8,7 @@ import torch.nn.functional as F
...
@@ -8,7 +8,7 @@ import torch.nn.functional as F
from
nni.retiarii
import
InvalidMutation
,
Sampler
,
basic_unit
from
nni.retiarii
import
InvalidMutation
,
Sampler
,
basic_unit
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.execution.
python
import
_unpack_if_only_one
from
nni.retiarii.execution.
utils
import
_unpack_if_only_one
from
nni.retiarii.nn.pytorch.mutator
import
process_inline_mutation
,
extract_mutation_from_pt_module
from
nni.retiarii.nn.pytorch.mutator
import
process_inline_mutation
,
extract_mutation_from_pt_module
from
nni.retiarii.serializer
import
model_wrapper
from
nni.retiarii.serializer
import
model_wrapper
from
nni.retiarii.utils
import
ContextStack
from
nni.retiarii.utils
import
ContextStack
...
...
ts/nni_manager/core/nnimanager.ts
View file @
357ec6ef
...
@@ -513,8 +513,9 @@ class NNIManager implements Manager {
...
@@ -513,8 +513,9 @@ class NNIManager implements Manager {
if
(
this
.
dispatcher
===
undefined
)
{
if
(
this
.
dispatcher
===
undefined
)
{
throw
new
Error
(
'
Error: tuner has not been setup
'
);
throw
new
Error
(
'
Error: tuner has not been setup
'
);
}
}
this
.
log
.
info
(
`Updated search space
${
searchSpace
}
`
);
this
.
dispatcher
.
sendCommand
(
UPDATE_SEARCH_SPACE
,
searchSpace
);
this
.
dispatcher
.
sendCommand
(
UPDATE_SEARCH_SPACE
,
searchSpace
);
this
.
experimentProfile
.
params
.
searchSpace
=
searchSpace
;
this
.
experimentProfile
.
params
.
searchSpace
=
JSON
.
parse
(
searchSpace
)
;
return
;
return
;
}
}
...
...
ts/webui/src/static/interface.ts
View file @
357ec6ef
...
@@ -228,6 +228,10 @@ interface SearchItems {
...
@@ -228,6 +228,10 @@ interface SearchItems {
isChoice
:
boolean
;
// for parameters: type = choice and status also as choice type
isChoice
:
boolean
;
// for parameters: type = choice and status also as choice type
}
}
interface
RetiariiParameter
{
mutation_summary
:
object
;
// retiarii experiment's parameter
}
export
{
export
{
TableObj
,
TableObj
,
TableRecord
,
TableRecord
,
...
@@ -253,5 +257,6 @@ export {
...
@@ -253,5 +257,6 @@ export {
SortInfo
,
SortInfo
,
AllExperimentList
,
AllExperimentList
,
Tensorboard
,
Tensorboard
,
SearchItems
SearchItems
,
RetiariiParameter
};
};
ts/webui/src/static/model/trial.ts
View file @
357ec6ef
...
@@ -7,7 +7,8 @@ import {
...
@@ -7,7 +7,8 @@ import {
Parameters
,
Parameters
,
FinalType
,
FinalType
,
MultipleAxes
,
MultipleAxes
,
SingleAxis
SingleAxis
,
RetiariiParameter
}
from
'
../interface
'
;
}
from
'
../interface
'
;
import
{
import
{
getFinal
,
getFinal
,
...
@@ -31,9 +32,11 @@ function inferTrialParameters(
...
@@ -31,9 +32,11 @@ function inferTrialParameters(
space
:
MultipleAxes
,
space
:
MultipleAxes
,
prefix
:
string
=
''
prefix
:
string
=
''
):
[
Map
<
SingleAxis
,
any
>
,
Map
<
string
,
any
>
]
{
):
[
Map
<
SingleAxis
,
any
>
,
Map
<
string
,
any
>
]
{
const
latestedParamObj
=
'
mutation_summary
'
in
paramObj
?
(
paramObj
as
RetiariiParameter
).
mutation_summary
:
paramObj
;
const
parameters
=
new
Map
<
SingleAxis
,
any
>
();
const
parameters
=
new
Map
<
SingleAxis
,
any
>
();
const
unexpectedEntries
=
new
Map
<
string
,
any
>
();
const
unexpectedEntries
=
new
Map
<
string
,
any
>
();
for
(
const
[
k
,
v
]
of
Object
.
entries
(
p
aramObj
))
{
for
(
const
[
k
,
v
]
of
Object
.
entries
(
latestedP
aramObj
))
{
// prefix can be a good fallback when corresponding item is not found in namespace
// prefix can be a good fallback when corresponding item is not found in namespace
const
axisKey
=
space
.
axes
.
get
(
k
);
const
axisKey
=
space
.
axes
.
get
(
k
);
if
(
prefix
&&
k
===
'
_name
'
)
continue
;
if
(
prefix
&&
k
===
'
_name
'
)
continue
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment