Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
357ec6ef
You need to sign in or sign up before continuing.
Unverified
Commit
357ec6ef
authored
Dec 10, 2021
by
QuanluZhang
Committed by
GitHub
Dec 10, 2021
Browse files
[retiarii] support visualize model space with the hpo chart on webui (#4304)
parent
844f670b
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
81 additions
and
27 deletions
+81
-27
nni/retiarii/execution/base.py
nni/retiarii/execution/base.py
+18
-5
nni/retiarii/execution/benchmark.py
nni/retiarii/execution/benchmark.py
+1
-1
nni/retiarii/execution/cgo_engine.py
nni/retiarii/execution/cgo_engine.py
+1
-1
nni/retiarii/execution/python.py
nni/retiarii/execution/python.py
+5
-12
nni/retiarii/execution/utils.py
nni/retiarii/execution/utils.py
+25
-0
nni/retiarii/experiment/pytorch.py
nni/retiarii/experiment/pytorch.py
+4
-1
nni/retiarii/integration_api.py
nni/retiarii/integration_api.py
+0
-1
nni/retiarii/strategy/local_debug_strategy.py
nni/retiarii/strategy/local_debug_strategy.py
+3
-1
nni/retiarii/strategy/utils.py
nni/retiarii/strategy/utils.py
+10
-0
test/ut/retiarii/test_highlevel_apis.py
test/ut/retiarii/test_highlevel_apis.py
+1
-1
ts/nni_manager/core/nnimanager.ts
ts/nni_manager/core/nnimanager.ts
+2
-1
ts/webui/src/static/interface.ts
ts/webui/src/static/interface.ts
+6
-1
ts/webui/src/static/model/trial.ts
ts/webui/src/static/model/trial.ts
+5
-2
No files found.
nni/retiarii/execution/base.py
View file @
357ec6ef
...
...
@@ -8,27 +8,39 @@ import string
from
typing
import
Any
,
Dict
,
Iterable
,
List
from
.interface
import
AbstractExecutionEngine
,
AbstractGraphListener
from
.utils
import
get_mutation_summary
from
..
import
codegen
,
utils
from
..graph
import
Model
,
ModelStatus
,
MetricData
,
Evaluator
from
..integration_api
import
send_trial
,
receive_trial_parameters
,
get_advisor
_logger
=
logging
.
getLogger
(
__name__
)
class
BaseGraphData
:
def
__init__
(
self
,
model_script
:
str
,
evaluator
:
Evaluator
)
->
None
:
"""
Attributes
----------
model_script
code of an instantiated PyTorch model
evaluator
training approach for model_script
mutation_summary
a dict of all the choices during mutations in the HPO search space format
"""
def
__init__
(
self
,
model_script
:
str
,
evaluator
:
Evaluator
,
mutation_summary
:
dict
)
->
None
:
self
.
model_script
=
model_script
self
.
evaluator
=
evaluator
self
.
mutation_summary
=
mutation_summary
def
dump
(
self
)
->
dict
:
return
{
'model_script'
:
self
.
model_script
,
'evaluator'
:
self
.
evaluator
'evaluator'
:
self
.
evaluator
,
'mutation_summary'
:
self
.
mutation_summary
}
@
staticmethod
def
load
(
data
)
->
'BaseGraphData'
:
return
BaseGraphData
(
data
[
'model_script'
],
data
[
'evaluator'
])
return
BaseGraphData
(
data
[
'model_script'
],
data
[
'evaluator'
]
,
data
[
'mutation_summary'
]
)
class
BaseExecutionEngine
(
AbstractExecutionEngine
):
...
...
@@ -111,7 +123,8 @@ class BaseExecutionEngine(AbstractExecutionEngine):
@
classmethod
def
pack_model_data
(
cls
,
model
:
Model
)
->
Any
:
return
BaseGraphData
(
codegen
.
model_to_pytorch_script
(
model
),
model
.
evaluator
)
mutation_summary
=
get_mutation_summary
(
model
)
return
BaseGraphData
(
codegen
.
model_to_pytorch_script
(
model
),
model
.
evaluator
,
mutation_summary
)
@
classmethod
def
trial_execute_graph
(
cls
)
->
None
:
...
...
nni/retiarii/execution/benchmark.py
View file @
357ec6ef
...
...
@@ -5,7 +5,7 @@ from typing import Dict, Any, List, Optional, Union, Tuple, Callable, Iterable
from
..graph
import
Model
from
..integration_api
import
receive_trial_parameters
from
.base
import
BaseExecutionEngine
from
.
python
import
get_mutation_dict
from
.
utils
import
get_mutation_dict
class
BenchmarkGraphData
:
...
...
nni/retiarii/execution/cgo_engine.py
View file @
357ec6ef
...
...
@@ -156,7 +156,7 @@ class CGOExecutionEngine(AbstractExecutionEngine):
phy_models_and_placements
=
self
.
_assemble
(
logical
)
for
model
,
placement
,
grouped_models
in
phy_models_and_placements
:
data
=
BaseGraphData
(
codegen
.
model_to_pytorch_script
(
model
,
placement
=
placement
),
model
.
evaluator
)
data
=
BaseGraphData
(
codegen
.
model_to_pytorch_script
(
model
,
placement
=
placement
),
model
.
evaluator
,
{}
)
placement_constraint
=
self
.
_extract_placement_constaint
(
placement
)
trial_id
=
send_trial
(
data
.
dump
(),
placement_constraint
=
placement_constraint
)
# unique non-cpu devices used by the trial
...
...
nni/retiarii/execution/python.py
View file @
357ec6ef
from
typing
import
Dict
,
Any
,
List
from
typing
import
Dict
,
Any
from
..graph
import
Evaluator
,
Model
from
..integration_api
import
receive_trial_parameters
from
..utils
import
ContextStack
,
import_
,
get_importable_name
from
.base
import
BaseExecutionEngine
from
.utils
import
get_mutation_dict
,
mutation_dict_to_summary
class
PythonGraphData
:
...
...
@@ -13,13 +14,15 @@ class PythonGraphData:
self
.
init_parameters
=
init_parameters
self
.
mutation
=
mutation
self
.
evaluator
=
evaluator
self
.
mutation_summary
=
mutation_dict_to_summary
(
mutation
)
def
dump
(
self
)
->
dict
:
return
{
'class_name'
:
self
.
class_name
,
'init_parameters'
:
self
.
init_parameters
,
'mutation'
:
self
.
mutation
,
'evaluator'
:
self
.
evaluator
'evaluator'
:
self
.
evaluator
,
'mutation_summary'
:
self
.
mutation_summary
}
@
staticmethod
...
...
@@ -55,13 +58,3 @@ class PurePythonExecutionEngine(BaseExecutionEngine):
with
ContextStack
(
'fixed'
,
graph_data
.
mutation
):
graph_data
.
evaluator
.
_execute
(
_model
)
def
_unpack_if_only_one
(
ele
:
List
[
Any
]):
if
len
(
ele
)
==
1
:
return
ele
[
0
]
return
ele
def
get_mutation_dict
(
model
:
Model
):
return
{
mut
.
mutator
.
label
:
_unpack_if_only_one
(
mut
.
samples
)
for
mut
in
model
.
history
}
nni/retiarii/execution/utils.py
0 → 100644
View file @
357ec6ef
from
typing
import
Any
,
List
from
..graph
import
Model
def
_unpack_if_only_one
(
ele
:
List
[
Any
]):
if
len
(
ele
)
==
1
:
return
ele
[
0
]
return
ele
def
get_mutation_dict
(
model
:
Model
):
return
{
mut
.
mutator
.
label
:
_unpack_if_only_one
(
mut
.
samples
)
for
mut
in
model
.
history
}
def
mutation_dict_to_summary
(
mutation
:
dict
)
->
dict
:
mutation_summary
=
{}
for
label
,
samples
in
mutation
.
items
():
# FIXME: this check might be wrong
if
not
isinstance
(
samples
,
list
):
mutation_summary
[
label
]
=
samples
else
:
for
i
,
sample
in
enumerate
(
samples
):
mutation_summary
[
f
'
{
label
}
_
{
i
}
'
]
=
sample
return
mutation_summary
def
get_mutation_summary
(
model
:
Model
)
->
dict
:
mutation
=
get_mutation_dict
(
model
)
return
mutation_dict_to_summary
(
mutation
)
nni/retiarii/experiment/pytorch.py
View file @
357ec6ef
...
...
@@ -28,13 +28,14 @@ from ..codegen import model_to_pytorch_script
from
..converter
import
convert_to_graph
from
..converter.graph_gen
import
GraphConverterWithShape
from
..execution
import
list_models
,
set_execution_engine
from
..execution.
python
import
get_mutation_dict
from
..execution.
utils
import
get_mutation_dict
from
..graph
import
Evaluator
from
..integration
import
RetiariiAdvisor
from
..mutator
import
Mutator
from
..nn.pytorch.mutator
import
extract_mutation_from_pt_module
,
process_inline_mutation
from
..oneshot.interface
import
BaseOneShotTrainer
from
..strategy
import
BaseStrategy
from
..strategy.utils
import
dry_run_for_formatted_search_space
_logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -193,6 +194,8 @@ class RetiariiExperiment(Experiment):
)
_logger
.
info
(
'Start strategy...'
)
search_space
=
dry_run_for_formatted_search_space
(
base_model_ir
,
self
.
applied_mutators
)
self
.
update_search_space
(
search_space
)
self
.
strategy
.
run
(
base_model_ir
,
self
.
applied_mutators
)
_logger
.
info
(
'Strategy exit'
)
# TODO: find out a proper way to show no more trial message on WebUI
...
...
nni/retiarii/integration_api.py
View file @
357ec6ef
...
...
@@ -31,7 +31,6 @@ def send_trial(parameters: dict, placement_constraint=None) -> int:
"""
return
get_advisor
().
send_trial
(
parameters
,
placement_constraint
)
def
receive_trial_parameters
()
->
dict
:
"""
Received a new trial. Executed on trial end.
...
...
nni/retiarii/strategy/local_debug_strategy.py
View file @
357ec6ef
...
...
@@ -8,6 +8,7 @@ import string
from
..
import
Sampler
,
codegen
,
utils
from
..execution.base
import
BaseGraphData
from
..execution.utils
import
get_mutation_summary
from
.base
import
BaseStrategy
_logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -22,7 +23,8 @@ class _LocalDebugStrategy(BaseStrategy):
"""
def
run_one_model
(
self
,
model
):
graph_data
=
BaseGraphData
(
codegen
.
model_to_pytorch_script
(
model
),
model
.
evaluator
)
mutation_summary
=
get_mutation_summary
(
model
)
graph_data
=
BaseGraphData
(
codegen
.
model_to_pytorch_script
(
model
),
model
.
evaluator
,
mutation_summary
)
random_str
=
''
.
join
(
random
.
choice
(
string
.
ascii_uppercase
+
string
.
digits
)
for
_
in
range
(
6
))
file_name
=
f
'_generated_model/
{
random_str
}
.py'
os
.
makedirs
(
os
.
path
.
dirname
(
file_name
),
exist_ok
=
True
)
...
...
nni/retiarii/strategy/utils.py
View file @
357ec6ef
...
...
@@ -27,6 +27,16 @@ def dry_run_for_search_space(model: Model, mutators: List[Mutator]) -> Dict[Any,
search_space
[(
mutator
,
i
)]
=
candidates
return
search_space
def
dry_run_for_formatted_search_space
(
model
:
Model
,
mutators
:
List
[
Mutator
])
->
Dict
[
Any
,
Dict
[
Any
,
Any
]]:
search_space
=
collections
.
OrderedDict
()
for
mutator
in
mutators
:
recorded_candidates
,
model
=
mutator
.
dry_run
(
model
)
if
len
(
recorded_candidates
)
==
1
:
search_space
[
mutator
.
label
]
=
{
'_type'
:
'choice'
,
'_value'
:
recorded_candidates
[
0
]}
else
:
for
i
,
candidate
in
enumerate
(
recorded_candidates
):
search_space
[
f
'
{
mutator
.
label
}
_
{
i
}
'
]
=
{
'_type'
:
'choice'
,
'_value'
:
candidate
}
return
search_space
def
get_targeted_model
(
base_model
:
Model
,
mutators
:
List
[
Mutator
],
sample
:
dict
)
->
Model
:
sampler
=
_FixedSampler
(
sample
)
...
...
test/ut/retiarii/test_highlevel_apis.py
View file @
357ec6ef
...
...
@@ -8,7 +8,7 @@ import torch.nn.functional as F
from
nni.retiarii
import
InvalidMutation
,
Sampler
,
basic_unit
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.execution.
python
import
_unpack_if_only_one
from
nni.retiarii.execution.
utils
import
_unpack_if_only_one
from
nni.retiarii.nn.pytorch.mutator
import
process_inline_mutation
,
extract_mutation_from_pt_module
from
nni.retiarii.serializer
import
model_wrapper
from
nni.retiarii.utils
import
ContextStack
...
...
ts/nni_manager/core/nnimanager.ts
View file @
357ec6ef
...
...
@@ -513,8 +513,9 @@ class NNIManager implements Manager {
if
(
this
.
dispatcher
===
undefined
)
{
throw
new
Error
(
'
Error: tuner has not been setup
'
);
}
this
.
log
.
info
(
`Updated search space
${
searchSpace
}
`
);
this
.
dispatcher
.
sendCommand
(
UPDATE_SEARCH_SPACE
,
searchSpace
);
this
.
experimentProfile
.
params
.
searchSpace
=
searchSpace
;
this
.
experimentProfile
.
params
.
searchSpace
=
JSON
.
parse
(
searchSpace
)
;
return
;
}
...
...
ts/webui/src/static/interface.ts
View file @
357ec6ef
...
...
@@ -228,6 +228,10 @@ interface SearchItems {
isChoice
:
boolean
;
// for parameters: type = choice and status also as choice type
}
interface
RetiariiParameter
{
mutation_summary
:
object
;
// retiarii experiment's parameter
}
export
{
TableObj
,
TableRecord
,
...
...
@@ -253,5 +257,6 @@ export {
SortInfo
,
AllExperimentList
,
Tensorboard
,
SearchItems
SearchItems
,
RetiariiParameter
};
ts/webui/src/static/model/trial.ts
View file @
357ec6ef
...
...
@@ -7,7 +7,8 @@ import {
Parameters
,
FinalType
,
MultipleAxes
,
SingleAxis
SingleAxis
,
RetiariiParameter
}
from
'
../interface
'
;
import
{
getFinal
,
...
...
@@ -31,9 +32,11 @@ function inferTrialParameters(
space
:
MultipleAxes
,
prefix
:
string
=
''
):
[
Map
<
SingleAxis
,
any
>
,
Map
<
string
,
any
>
]
{
const
latestedParamObj
=
'
mutation_summary
'
in
paramObj
?
(
paramObj
as
RetiariiParameter
).
mutation_summary
:
paramObj
;
const
parameters
=
new
Map
<
SingleAxis
,
any
>
();
const
unexpectedEntries
=
new
Map
<
string
,
any
>
();
for
(
const
[
k
,
v
]
of
Object
.
entries
(
p
aramObj
))
{
for
(
const
[
k
,
v
]
of
Object
.
entries
(
latestedP
aramObj
))
{
// prefix can be a good fallback when corresponding item is not found in namespace
const
axisKey
=
space
.
axes
.
get
(
k
);
if
(
prefix
&&
k
===
'
_name
'
)
continue
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment