Unverified Commit 357ec6ef authored by QuanluZhang's avatar QuanluZhang Committed by GitHub
Browse files

[retiarii] support visualize model space with the hpo chart on webui (#4304)

parent 844f670b
...@@ -8,27 +8,39 @@ import string ...@@ -8,27 +8,39 @@ import string
from typing import Any, Dict, Iterable, List from typing import Any, Dict, Iterable, List
from .interface import AbstractExecutionEngine, AbstractGraphListener from .interface import AbstractExecutionEngine, AbstractGraphListener
from .utils import get_mutation_summary
from .. import codegen, utils from .. import codegen, utils
from ..graph import Model, ModelStatus, MetricData, Evaluator from ..graph import Model, ModelStatus, MetricData, Evaluator
from ..integration_api import send_trial, receive_trial_parameters, get_advisor from ..integration_api import send_trial, receive_trial_parameters, get_advisor
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
class BaseGraphData: class BaseGraphData:
def __init__(self, model_script: str, evaluator: Evaluator) -> None: """
Attributes
----------
model_script
code of an instantiated PyTorch model
evaluator
training approach for model_script
mutation_summary
a dict of all the choices during mutations in the HPO search space format
"""
def __init__(self, model_script: str, evaluator: Evaluator, mutation_summary: dict) -> None:
self.model_script = model_script self.model_script = model_script
self.evaluator = evaluator self.evaluator = evaluator
self.mutation_summary = mutation_summary
def dump(self) -> dict: def dump(self) -> dict:
return { return {
'model_script': self.model_script, 'model_script': self.model_script,
'evaluator': self.evaluator 'evaluator': self.evaluator,
'mutation_summary': self.mutation_summary
} }
@staticmethod @staticmethod
def load(data) -> 'BaseGraphData': def load(data) -> 'BaseGraphData':
return BaseGraphData(data['model_script'], data['evaluator']) return BaseGraphData(data['model_script'], data['evaluator'], data['mutation_summary'])
class BaseExecutionEngine(AbstractExecutionEngine): class BaseExecutionEngine(AbstractExecutionEngine):
...@@ -111,7 +123,8 @@ class BaseExecutionEngine(AbstractExecutionEngine): ...@@ -111,7 +123,8 @@ class BaseExecutionEngine(AbstractExecutionEngine):
@classmethod @classmethod
def pack_model_data(cls, model: Model) -> Any: def pack_model_data(cls, model: Model) -> Any:
return BaseGraphData(codegen.model_to_pytorch_script(model), model.evaluator) mutation_summary = get_mutation_summary(model)
return BaseGraphData(codegen.model_to_pytorch_script(model), model.evaluator, mutation_summary)
@classmethod @classmethod
def trial_execute_graph(cls) -> None: def trial_execute_graph(cls) -> None:
......
...@@ -5,7 +5,7 @@ from typing import Dict, Any, List, Optional, Union, Tuple, Callable, Iterable ...@@ -5,7 +5,7 @@ from typing import Dict, Any, List, Optional, Union, Tuple, Callable, Iterable
from ..graph import Model from ..graph import Model
from ..integration_api import receive_trial_parameters from ..integration_api import receive_trial_parameters
from .base import BaseExecutionEngine from .base import BaseExecutionEngine
from .python import get_mutation_dict from .utils import get_mutation_dict
class BenchmarkGraphData: class BenchmarkGraphData:
......
...@@ -156,7 +156,7 @@ class CGOExecutionEngine(AbstractExecutionEngine): ...@@ -156,7 +156,7 @@ class CGOExecutionEngine(AbstractExecutionEngine):
phy_models_and_placements = self._assemble(logical) phy_models_and_placements = self._assemble(logical)
for model, placement, grouped_models in phy_models_and_placements: for model, placement, grouped_models in phy_models_and_placements:
data = BaseGraphData(codegen.model_to_pytorch_script(model, placement=placement), model.evaluator) data = BaseGraphData(codegen.model_to_pytorch_script(model, placement=placement), model.evaluator, {})
placement_constraint = self._extract_placement_constaint(placement) placement_constraint = self._extract_placement_constaint(placement)
trial_id = send_trial(data.dump(), placement_constraint=placement_constraint) trial_id = send_trial(data.dump(), placement_constraint=placement_constraint)
# unique non-cpu devices used by the trial # unique non-cpu devices used by the trial
......
from typing import Dict, Any, List from typing import Dict, Any
from ..graph import Evaluator, Model from ..graph import Evaluator, Model
from ..integration_api import receive_trial_parameters from ..integration_api import receive_trial_parameters
from ..utils import ContextStack, import_, get_importable_name from ..utils import ContextStack, import_, get_importable_name
from .base import BaseExecutionEngine from .base import BaseExecutionEngine
from .utils import get_mutation_dict, mutation_dict_to_summary
class PythonGraphData: class PythonGraphData:
...@@ -13,13 +14,15 @@ class PythonGraphData: ...@@ -13,13 +14,15 @@ class PythonGraphData:
self.init_parameters = init_parameters self.init_parameters = init_parameters
self.mutation = mutation self.mutation = mutation
self.evaluator = evaluator self.evaluator = evaluator
self.mutation_summary = mutation_dict_to_summary(mutation)
def dump(self) -> dict: def dump(self) -> dict:
return { return {
'class_name': self.class_name, 'class_name': self.class_name,
'init_parameters': self.init_parameters, 'init_parameters': self.init_parameters,
'mutation': self.mutation, 'mutation': self.mutation,
'evaluator': self.evaluator 'evaluator': self.evaluator,
'mutation_summary': self.mutation_summary
} }
@staticmethod @staticmethod
...@@ -55,13 +58,3 @@ class PurePythonExecutionEngine(BaseExecutionEngine): ...@@ -55,13 +58,3 @@ class PurePythonExecutionEngine(BaseExecutionEngine):
with ContextStack('fixed', graph_data.mutation): with ContextStack('fixed', graph_data.mutation):
graph_data.evaluator._execute(_model) graph_data.evaluator._execute(_model)
def _unpack_if_only_one(ele: List[Any]):
if len(ele) == 1:
return ele[0]
return ele
def get_mutation_dict(model: Model):
return {mut.mutator.label: _unpack_if_only_one(mut.samples) for mut in model.history}
from typing import Any, List
from ..graph import Model
def _unpack_if_only_one(ele: List[Any]):
if len(ele) == 1:
return ele[0]
return ele
def get_mutation_dict(model: Model):
return {mut.mutator.label: _unpack_if_only_one(mut.samples) for mut in model.history}
def mutation_dict_to_summary(mutation: dict) -> dict:
mutation_summary = {}
for label, samples in mutation.items():
# FIXME: this check might be wrong
if not isinstance(samples, list):
mutation_summary[label] = samples
else:
for i, sample in enumerate(samples):
mutation_summary[f'{label}_{i}'] = sample
return mutation_summary
def get_mutation_summary(model: Model) -> dict:
mutation = get_mutation_dict(model)
return mutation_dict_to_summary(mutation)
...@@ -28,13 +28,14 @@ from ..codegen import model_to_pytorch_script ...@@ -28,13 +28,14 @@ from ..codegen import model_to_pytorch_script
from ..converter import convert_to_graph from ..converter import convert_to_graph
from ..converter.graph_gen import GraphConverterWithShape from ..converter.graph_gen import GraphConverterWithShape
from ..execution import list_models, set_execution_engine from ..execution import list_models, set_execution_engine
from ..execution.python import get_mutation_dict from ..execution.utils import get_mutation_dict
from ..graph import Evaluator from ..graph import Evaluator
from ..integration import RetiariiAdvisor from ..integration import RetiariiAdvisor
from ..mutator import Mutator from ..mutator import Mutator
from ..nn.pytorch.mutator import extract_mutation_from_pt_module, process_inline_mutation from ..nn.pytorch.mutator import extract_mutation_from_pt_module, process_inline_mutation
from ..oneshot.interface import BaseOneShotTrainer from ..oneshot.interface import BaseOneShotTrainer
from ..strategy import BaseStrategy from ..strategy import BaseStrategy
from ..strategy.utils import dry_run_for_formatted_search_space
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
...@@ -193,6 +194,8 @@ class RetiariiExperiment(Experiment): ...@@ -193,6 +194,8 @@ class RetiariiExperiment(Experiment):
) )
_logger.info('Start strategy...') _logger.info('Start strategy...')
search_space = dry_run_for_formatted_search_space(base_model_ir, self.applied_mutators)
self.update_search_space(search_space)
self.strategy.run(base_model_ir, self.applied_mutators) self.strategy.run(base_model_ir, self.applied_mutators)
_logger.info('Strategy exit') _logger.info('Strategy exit')
# TODO: find out a proper way to show no more trial message on WebUI # TODO: find out a proper way to show no more trial message on WebUI
......
...@@ -31,7 +31,6 @@ def send_trial(parameters: dict, placement_constraint=None) -> int: ...@@ -31,7 +31,6 @@ def send_trial(parameters: dict, placement_constraint=None) -> int:
""" """
return get_advisor().send_trial(parameters, placement_constraint) return get_advisor().send_trial(parameters, placement_constraint)
def receive_trial_parameters() -> dict: def receive_trial_parameters() -> dict:
""" """
Received a new trial. Executed on trial end. Received a new trial. Executed on trial end.
......
...@@ -8,6 +8,7 @@ import string ...@@ -8,6 +8,7 @@ import string
from .. import Sampler, codegen, utils from .. import Sampler, codegen, utils
from ..execution.base import BaseGraphData from ..execution.base import BaseGraphData
from ..execution.utils import get_mutation_summary
from .base import BaseStrategy from .base import BaseStrategy
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
...@@ -22,7 +23,8 @@ class _LocalDebugStrategy(BaseStrategy): ...@@ -22,7 +23,8 @@ class _LocalDebugStrategy(BaseStrategy):
""" """
def run_one_model(self, model): def run_one_model(self, model):
graph_data = BaseGraphData(codegen.model_to_pytorch_script(model), model.evaluator) mutation_summary = get_mutation_summary(model)
graph_data = BaseGraphData(codegen.model_to_pytorch_script(model), model.evaluator, mutation_summary)
random_str = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(6)) random_str = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(6))
file_name = f'_generated_model/{random_str}.py' file_name = f'_generated_model/{random_str}.py'
os.makedirs(os.path.dirname(file_name), exist_ok=True) os.makedirs(os.path.dirname(file_name), exist_ok=True)
......
...@@ -27,6 +27,16 @@ def dry_run_for_search_space(model: Model, mutators: List[Mutator]) -> Dict[Any, ...@@ -27,6 +27,16 @@ def dry_run_for_search_space(model: Model, mutators: List[Mutator]) -> Dict[Any,
search_space[(mutator, i)] = candidates search_space[(mutator, i)] = candidates
return search_space return search_space
def dry_run_for_formatted_search_space(model: Model, mutators: List[Mutator]) -> Dict[Any, Dict[Any, Any]]:
search_space = collections.OrderedDict()
for mutator in mutators:
recorded_candidates, model = mutator.dry_run(model)
if len(recorded_candidates) == 1:
search_space[mutator.label] = {'_type': 'choice', '_value': recorded_candidates[0]}
else:
for i, candidate in enumerate(recorded_candidates):
search_space[f'{mutator.label}_{i}'] = {'_type': 'choice', '_value': candidate}
return search_space
def get_targeted_model(base_model: Model, mutators: List[Mutator], sample: dict) -> Model: def get_targeted_model(base_model: Model, mutators: List[Mutator], sample: dict) -> Model:
sampler = _FixedSampler(sample) sampler = _FixedSampler(sample)
......
...@@ -8,7 +8,7 @@ import torch.nn.functional as F ...@@ -8,7 +8,7 @@ import torch.nn.functional as F
from nni.retiarii import InvalidMutation, Sampler, basic_unit from nni.retiarii import InvalidMutation, Sampler, basic_unit
from nni.retiarii.converter import convert_to_graph from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.execution.python import _unpack_if_only_one from nni.retiarii.execution.utils import _unpack_if_only_one
from nni.retiarii.nn.pytorch.mutator import process_inline_mutation, extract_mutation_from_pt_module from nni.retiarii.nn.pytorch.mutator import process_inline_mutation, extract_mutation_from_pt_module
from nni.retiarii.serializer import model_wrapper from nni.retiarii.serializer import model_wrapper
from nni.retiarii.utils import ContextStack from nni.retiarii.utils import ContextStack
......
...@@ -513,8 +513,9 @@ class NNIManager implements Manager { ...@@ -513,8 +513,9 @@ class NNIManager implements Manager {
if (this.dispatcher === undefined) { if (this.dispatcher === undefined) {
throw new Error('Error: tuner has not been setup'); throw new Error('Error: tuner has not been setup');
} }
this.log.info(`Updated search space ${searchSpace}`);
this.dispatcher.sendCommand(UPDATE_SEARCH_SPACE, searchSpace); this.dispatcher.sendCommand(UPDATE_SEARCH_SPACE, searchSpace);
this.experimentProfile.params.searchSpace = searchSpace; this.experimentProfile.params.searchSpace = JSON.parse(searchSpace);
return; return;
} }
......
...@@ -228,6 +228,10 @@ interface SearchItems { ...@@ -228,6 +228,10 @@ interface SearchItems {
isChoice: boolean; // for parameters: type = choice and status also as choice type isChoice: boolean; // for parameters: type = choice and status also as choice type
} }
interface RetiariiParameter {
mutation_summary: object; // retiarii experiment's parameter
}
export { export {
TableObj, TableObj,
TableRecord, TableRecord,
...@@ -253,5 +257,6 @@ export { ...@@ -253,5 +257,6 @@ export {
SortInfo, SortInfo,
AllExperimentList, AllExperimentList,
Tensorboard, Tensorboard,
SearchItems SearchItems,
RetiariiParameter
}; };
...@@ -7,7 +7,8 @@ import { ...@@ -7,7 +7,8 @@ import {
Parameters, Parameters,
FinalType, FinalType,
MultipleAxes, MultipleAxes,
SingleAxis SingleAxis,
RetiariiParameter
} from '../interface'; } from '../interface';
import { import {
getFinal, getFinal,
...@@ -31,9 +32,11 @@ function inferTrialParameters( ...@@ -31,9 +32,11 @@ function inferTrialParameters(
space: MultipleAxes, space: MultipleAxes,
prefix: string = '' prefix: string = ''
): [Map<SingleAxis, any>, Map<string, any>] { ): [Map<SingleAxis, any>, Map<string, any>] {
const latestedParamObj =
'mutation_summary' in paramObj ? (paramObj as RetiariiParameter).mutation_summary : paramObj;
const parameters = new Map<SingleAxis, any>(); const parameters = new Map<SingleAxis, any>();
const unexpectedEntries = new Map<string, any>(); const unexpectedEntries = new Map<string, any>();
for (const [k, v] of Object.entries(paramObj)) { for (const [k, v] of Object.entries(latestedParamObj)) {
// prefix can be a good fallback when corresponding item is not found in namespace // prefix can be a good fallback when corresponding item is not found in namespace
const axisKey = space.axes.get(k); const axisKey = space.axes.get(k);
if (prefix && k === '_name') continue; if (prefix && k === '_name') continue;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment