"git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "b49cceb91e7ef02bfd3059fc71a6c4f32bf4aca1"
Unverified Commit 2f3c3951 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

[Retiarii] Pure-python execution engine (#3605)

parent 122b5b89
Advanced Tutorial Advanced Tutorial
================= =================
This document includes two parts. The first part explains the design decision of ``@basic_unit`` and ``serializer``. The second part is the tutorial of how to write a model space with mutators. Pure-python execution engine (experimental)
-------------------------------------------
If you are experiencing issues with TorchScript, or the generated model code by Retiarii, there is another execution engine called Pure-python execution engine which doesn't need the code-graph conversion. This should generally not affect models and strategies in most cases, but customized mutation might not be supported.
This will become the default execution engine in a future version of Retiarii.
Two steps are needed to enable this engine now.
1. Add ``@nni.retiarii.model_wrapper`` decorator outside the whole PyTorch model.
2. Add ``config.execution_engine = 'py'`` to ``RetiariiExeConfig``.
``@basic_unit`` and ``serializer`` ``@basic_unit`` and ``serializer``
---------------------------------- ----------------------------------
......
...@@ -5,4 +5,4 @@ from .operation import Operation ...@@ -5,4 +5,4 @@ from .operation import Operation
from .graph import * from .graph import *
from .execution import * from .execution import *
from .mutator import * from .mutator import *
from .serializer import basic_unit, json_dump, json_dumps, json_load, json_loads, serialize, serialize_cls from .serializer import basic_unit, json_dump, json_dumps, json_load, json_loads, serialize, serialize_cls, model_wrapper
...@@ -15,19 +15,18 @@ __all__ = ['get_execution_engine', 'get_and_register_default_listener', ...@@ -15,19 +15,18 @@ __all__ = ['get_execution_engine', 'get_and_register_default_listener',
'list_models', 'submit_models', 'wait_models', 'query_available_resources', 'list_models', 'submit_models', 'wait_models', 'query_available_resources',
'set_execution_engine', 'is_stopped_exec', 'budget_exhausted'] 'set_execution_engine', 'is_stopped_exec', 'budget_exhausted']
def set_execution_engine(engine) -> None:
def set_execution_engine(engine: AbstractExecutionEngine) -> None:
global _execution_engine global _execution_engine
if _execution_engine is None: if _execution_engine is None:
_execution_engine = engine _execution_engine = engine
else: else:
raise RuntimeError('execution engine is already set') raise RuntimeError('Execution engine is already set.')
def get_execution_engine() -> AbstractExecutionEngine: def get_execution_engine() -> AbstractExecutionEngine:
"""
Currently we assume the default execution engine is BaseExecutionEngine.
"""
global _execution_engine global _execution_engine
assert _execution_engine is not None, 'You need to set execution engine, before using it.'
return _execution_engine return _execution_engine
......
...@@ -5,7 +5,7 @@ import logging ...@@ -5,7 +5,7 @@ import logging
import os import os
import random import random
import string import string
from typing import Dict, Iterable, List from typing import Any, Dict, Iterable, List
from .interface import AbstractExecutionEngine, AbstractGraphListener from .interface import AbstractExecutionEngine, AbstractGraphListener
from .. import codegen, utils from .. import codegen, utils
...@@ -59,7 +59,7 @@ class BaseExecutionEngine(AbstractExecutionEngine): ...@@ -59,7 +59,7 @@ class BaseExecutionEngine(AbstractExecutionEngine):
def submit_models(self, *models: Model) -> None: def submit_models(self, *models: Model) -> None:
for model in models: for model in models:
data = BaseGraphData(codegen.model_to_pytorch_script(model), model.evaluator) data = self.pack_model_data(model)
self._running_models[send_trial(data.dump())] = model self._running_models[send_trial(data.dump())] = model
self._history.append(model) self._history.append(model)
...@@ -108,6 +108,10 @@ class BaseExecutionEngine(AbstractExecutionEngine): ...@@ -108,6 +108,10 @@ class BaseExecutionEngine(AbstractExecutionEngine):
advisor = get_advisor() advisor = get_advisor()
return advisor.stopping return advisor.stopping
@classmethod
def pack_model_data(cls, model: Model) -> Any:
return BaseGraphData(codegen.model_to_pytorch_script(model), model.evaluator)
@classmethod @classmethod
def trial_execute_graph(cls) -> None: def trial_execute_graph(cls) -> None:
""" """
......
from typing import Dict, Any, List
from ..graph import Evaluator, Model
from ..integration_api import receive_trial_parameters
from ..utils import ContextStack, import_, get_importable_name
from .base import BaseExecutionEngine
class PythonGraphData:
    """
    Serializable payload describing a model for the pure-python execution engine:
    the importable class name, its init parameters, the recorded mutation
    decisions (mutator label -> sample(s)), and the evaluator to run it with.
    """

    def __init__(self, class_name: str, init_parameters: Dict[str, Any],
                 mutation: Dict[str, Any], evaluator: Evaluator) -> None:
        self.class_name = class_name
        self.init_parameters = init_parameters
        self.mutation = mutation
        self.evaluator = evaluator

    def dump(self) -> dict:
        """Serialize to a plain dict; inverse of :meth:`load`."""
        return dict(class_name=self.class_name,
                    init_parameters=self.init_parameters,
                    mutation=self.mutation,
                    evaluator=self.evaluator)

    @staticmethod
    def load(data) -> 'PythonGraphData':
        """Reconstruct a ``PythonGraphData`` from the dict produced by :meth:`dump`."""
        return PythonGraphData(data['class_name'], data['init_parameters'],
                               data['mutation'], data['evaluator'])
class PurePythonExecutionEngine(BaseExecutionEngine):
    """
    Execution engine that ships the python class plus its mutation decisions to the
    trial, instead of generated model code (no code-graph conversion needed).
    """

    @classmethod
    def pack_model_data(cls, model: Model) -> Any:
        # Build the replay log: mutator label -> sampled decision (unwrapped if single).
        decisions = {}
        for record in model.history:
            decisions[record.mutator.label] = _unpack_if_only_one(record.samples)
        class_name = get_importable_name(model.python_class, relocate_module=True)
        return PythonGraphData(class_name, model.python_init_params, decisions, model.evaluator)

    @classmethod
    def trial_execute_graph(cls) -> None:
        """Trial-side entry point: rebuild the model class and hand it to the evaluator."""
        data = PythonGraphData.load(receive_trial_parameters())

        # Subclass binds the recorded init parameters so the evaluator can
        # instantiate the model with a no-argument constructor.
        class _model(import_(data.class_name)):
            def __init__(self):
                super().__init__(**data.init_parameters)

        # The 'fixed' context makes LayerChoice/InputChoice/ValueChoice resolve
        # to the frozen decisions during construction.
        with ContextStack('fixed', data.mutation):
            data.evaluator._execute(_model)
def _unpack_if_only_one(ele: List[Any]):
if len(ele) == 1:
return ele[0]
return ele
...@@ -28,11 +28,11 @@ from nni.tools.nnictl.command_utils import kill_command ...@@ -28,11 +28,11 @@ from nni.tools.nnictl.command_utils import kill_command
from ..codegen import model_to_pytorch_script from ..codegen import model_to_pytorch_script
from ..converter import convert_to_graph from ..converter import convert_to_graph
from ..execution import list_models from ..execution import list_models, set_execution_engine
from ..graph import Model, Evaluator from ..graph import Model, Evaluator
from ..integration import RetiariiAdvisor from ..integration import RetiariiAdvisor
from ..mutator import Mutator from ..mutator import Mutator
from ..nn.pytorch.mutator import process_inline_mutation from ..nn.pytorch.mutator import process_inline_mutation, extract_mutation_from_pt_module
from ..strategy import BaseStrategy from ..strategy import BaseStrategy
from ..oneshot.interface import BaseOneShotTrainer from ..oneshot.interface import BaseOneShotTrainer
...@@ -43,7 +43,7 @@ _logger = logging.getLogger(__name__) ...@@ -43,7 +43,7 @@ _logger = logging.getLogger(__name__)
class RetiariiExeConfig(ConfigBase): class RetiariiExeConfig(ConfigBase):
experiment_name: Optional[str] = None experiment_name: Optional[str] = None
search_space: Any = '' # TODO: remove search_space: Any = '' # TODO: remove
trial_command: str = 'python3 -m nni.retiarii.trial_entry' trial_command: str = '_reserved'
trial_code_directory: PathLike = '.' trial_code_directory: PathLike = '.'
trial_concurrency: int trial_concurrency: int
trial_gpu_number: int = 0 trial_gpu_number: int = 0
...@@ -55,21 +55,26 @@ class RetiariiExeConfig(ConfigBase): ...@@ -55,21 +55,26 @@ class RetiariiExeConfig(ConfigBase):
experiment_working_directory: PathLike = '~/nni-experiments' experiment_working_directory: PathLike = '~/nni-experiments'
# remove configuration of tuner/assessor/advisor # remove configuration of tuner/assessor/advisor
training_service: TrainingServiceConfig training_service: TrainingServiceConfig
execution_engine: str = 'base'
def __init__(self, training_service_platform: Optional[str] = None, **kwargs): def __init__(self, training_service_platform: Optional[str] = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
if training_service_platform is not None: if training_service_platform is not None:
assert 'training_service' not in kwargs assert 'training_service' not in kwargs
self.training_service = util.training_service_config_factory(platform = training_service_platform) self.training_service = util.training_service_config_factory(platform = training_service_platform)
self.__dict__['trial_command'] = 'python3 -m nni.retiarii.trial_entry base'
def __setattr__(self, key, value): def __setattr__(self, key, value):
fixed_attrs = {'search_space': '', fixed_attrs = {'search_space': '',
'trial_command': 'python3 -m nni.retiarii.trial_entry'} 'trial_command': '_reserved'}
if key in fixed_attrs and fixed_attrs[key] != value: if key in fixed_attrs and fixed_attrs[key] != value:
raise AttributeError(f'{key} is not supposed to be set in Retiarii mode by users!') raise AttributeError(f'{key} is not supposed to be set in Retiarii mode by users!')
# 'trial_code_directory' is handled differently because the path will be converted to absolute path by us # 'trial_code_directory' is handled differently because the path will be converted to absolute path by us
if key == 'trial_code_directory' and not (value == Path('.') or os.path.isabs(value)): if key == 'trial_code_directory' and not (value == Path('.') or os.path.isabs(value)):
raise AttributeError(f'{key} is not supposed to be set in Retiarii mode by users!') raise AttributeError(f'{key} is not supposed to be set in Retiarii mode by users!')
if key == 'execution_engine':
assert value in ['base', 'py', 'cgo'], f'The specified execution engine "{value}" is not supported.'
self.__dict__['trial_command'] = 'python3 -m nni.retiarii.trial_entry ' + value
self.__dict__[key] = value self.__dict__[key] = value
def validate(self, initialized_tuner: bool = False) -> None: def validate(self, initialized_tuner: bool = False) -> None:
...@@ -100,23 +105,27 @@ _validation_rules = { ...@@ -100,23 +105,27 @@ _validation_rules = {
'training_service': lambda value: (type(value) is not TrainingServiceConfig, 'cannot be abstract base class') 'training_service': lambda value: (type(value) is not TrainingServiceConfig, 'cannot be abstract base class')
} }
def preprocess_model(base_model, trainer, applied_mutators): def preprocess_model(base_model, trainer, applied_mutators, full_ir=True):
# TODO: this logic might need to be refactored into execution engine
if full_ir:
try: try:
script_module = torch.jit.script(base_model) script_module = torch.jit.script(base_model)
except Exception as e: except Exception as e:
_logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:') _logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:')
raise e raise e
base_model_ir = convert_to_graph(script_module, base_model) base_model_ir = convert_to_graph(script_module, base_model)
base_model_ir.evaluator = trainer
# handle inline mutations # handle inline mutations
mutators = process_inline_mutation(base_model_ir) mutators = process_inline_mutation(base_model_ir)
if mutators is not None and applied_mutators: else:
raise RuntimeError('Have not supported mixed usage of LayerChoice/InputChoice and mutators, ' base_model_ir, mutators = extract_mutation_from_pt_module(base_model)
'do not use mutators when you use LayerChoice/InputChoice') base_model_ir.evaluator = trainer
if mutators is not None:
applied_mutators = mutators if mutators is not None and applied_mutators:
return base_model_ir, applied_mutators raise RuntimeError('Have not supported mixed usage of LayerChoice/InputChoice and mutators, '
'do not use mutators when you use LayerChoice/InputChoice')
if mutators is not None:
applied_mutators = mutators
return base_model_ir, applied_mutators
def debug_mutated_model(base_model, trainer, applied_mutators): def debug_mutated_model(base_model, trainer, applied_mutators):
""" """
...@@ -160,7 +169,8 @@ class RetiariiExperiment(Experiment): ...@@ -160,7 +169,8 @@ class RetiariiExperiment(Experiment):
self._pipe: Optional[Pipe] = None self._pipe: Optional[Pipe] = None
def _start_strategy(self): def _start_strategy(self):
base_model_ir, self.applied_mutators = preprocess_model(self.base_model, self.trainer, self.applied_mutators) base_model_ir, self.applied_mutators = preprocess_model(
self.base_model, self.trainer, self.applied_mutators, full_ir=self.config.execution_engine != 'py')
_logger.info('Start strategy...') _logger.info('Start strategy...')
self.strategy.run(base_model_ir, self.applied_mutators) self.strategy.run(base_model_ir, self.applied_mutators)
...@@ -182,6 +192,18 @@ class RetiariiExperiment(Experiment): ...@@ -182,6 +192,18 @@ class RetiariiExperiment(Experiment):
""" """
atexit.register(self.stop) atexit.register(self.stop)
# we will probably need a execution engine factory to make this clean and elegant
if self.config.execution_engine == 'base':
from ..execution.base import BaseExecutionEngine
engine = BaseExecutionEngine()
elif self.config.execution_engine == 'cgo':
from ..execution.cgo_engine import CGOExecutionEngine
engine = CGOExecutionEngine()
elif self.config.execution_engine == 'py':
from ..execution.python import PurePythonExecutionEngine
engine = PurePythonExecutionEngine()
set_execution_engine(engine)
self.id = management.generate_experiment_id() self.id = management.generate_experiment_id()
if self.config.experiment_working_directory is not None: if self.config.experiment_working_directory is not None:
......
...@@ -9,12 +9,12 @@ import abc ...@@ -9,12 +9,12 @@ import abc
import copy import copy
import json import json
from enum import Enum from enum import Enum
from typing import (Any, Dict, Iterable, List, Optional, Tuple, Union, overload) from typing import (Any, Dict, Iterable, List, Optional, Tuple, Type, Union, overload)
from .operation import Cell, Operation, _IOPseudoOperation from .operation import Cell, Operation, _IOPseudoOperation
from .utils import get_importable_name, import_, uid from .utils import get_importable_name, import_, uid
__all__ = ['Model', 'ModelStatus', 'Graph', 'Node', 'Edge', 'IllegalGraphError', 'MetricData'] __all__ = ['Model', 'ModelStatus', 'Graph', 'Node', 'Edge', 'Mutation', 'IllegalGraphError', 'MetricData']
MetricData = Any MetricData = Any
...@@ -80,6 +80,10 @@ class Model: ...@@ -80,6 +80,10 @@ class Model:
Attributes Attributes
---------- ----------
python_class
Python class that base model is converted from.
python_init_params
Initialization parameters of python class.
status status
See `ModelStatus`. See `ModelStatus`.
root_graph root_graph
...@@ -102,6 +106,8 @@ class Model: ...@@ -102,6 +106,8 @@ class Model:
def __init__(self, _internal=False): def __init__(self, _internal=False):
assert _internal, '`Model()` is private, use `model.fork()` instead' assert _internal, '`Model()` is private, use `model.fork()` instead'
self.model_id: int = uid('model') self.model_id: int = uid('model')
self.python_class: Optional[Type] = None
self.python_init_params: Optional[Dict[str, Any]] = None
self.status: ModelStatus = ModelStatus.Mutating self.status: ModelStatus = ModelStatus.Mutating
...@@ -116,7 +122,8 @@ class Model: ...@@ -116,7 +122,8 @@ class Model:
def __repr__(self): def __repr__(self):
return f'Model(model_id={self.model_id}, status={self.status}, graphs={list(self.graphs.keys())}, ' + \ return f'Model(model_id={self.model_id}, status={self.status}, graphs={list(self.graphs.keys())}, ' + \
f'evaluator={self.evaluator}, metric={self.metric}, intermediate_metrics={self.intermediate_metrics})' f'evaluator={self.evaluator}, metric={self.metric}, intermediate_metrics={self.intermediate_metrics}, ' + \
f'python_class={self.python_class})'
@property @property
def root_graph(self) -> 'Graph': def root_graph(self) -> 'Graph':
...@@ -133,9 +140,12 @@ class Model: ...@@ -133,9 +140,12 @@ class Model:
""" """
new_model = Model(_internal=True) new_model = Model(_internal=True)
new_model._root_graph_name = self._root_graph_name new_model._root_graph_name = self._root_graph_name
new_model.python_class = self.python_class
new_model.python_init_params = self.python_init_params
new_model.graphs = {name: graph._fork_to(new_model) for name, graph in self.graphs.items()} new_model.graphs = {name: graph._fork_to(new_model) for name, graph in self.graphs.items()}
new_model.evaluator = copy.deepcopy(self.evaluator) # TODO this may be a problem when evaluator is large new_model.evaluator = copy.deepcopy(self.evaluator) # TODO this may be a problem when evaluator is large
new_model.history = self.history + [self] new_model.history = [*self.history]
# Note: the history is not updated. It will be updated when the model is changed, that is in mutator.
return new_model return new_model
@staticmethod @staticmethod
...@@ -167,8 +177,8 @@ class Model: ...@@ -167,8 +177,8 @@ class Model:
def get_nodes_by_label(self, label: str) -> List['Node']: def get_nodes_by_label(self, label: str) -> List['Node']:
""" """
Traverse all the nodes to find the matched node(s) with the given name. Traverse all the nodes to find the matched node(s) with the given label.
There could be multiple nodes with the same name. Name space name can uniquely There could be multiple nodes with the same label. Name space name can uniquely
identify a graph or node. identify a graph or node.
NOTE: the implementation does not support the class abstraction NOTE: the implementation does not support the class abstraction
...@@ -493,6 +503,8 @@ class Node: ...@@ -493,6 +503,8 @@ class Node:
If two models have nodes with same ID, they are semantically the same node. If two models have nodes with same ID, they are semantically the same node.
name name
Mnemonic name. It should have an one-to-one mapping with ID. Mnemonic name. It should have an one-to-one mapping with ID.
label
Optional. If two nodes have the same label, they are considered same by the mutator.
operation operation
... ...
cell cell
...@@ -515,7 +527,7 @@ class Node: ...@@ -515,7 +527,7 @@ class Node:
# TODO: the operation is likely to be considered editable by end-user and it will be hard to debug # TODO: the operation is likely to be considered editable by end-user and it will be hard to debug
# maybe we should copy it here or make Operation class immutable, in next release # maybe we should copy it here or make Operation class immutable, in next release
self.operation: Operation = operation self.operation: Operation = operation
self.label: str = None self.label: Optional[str] = None
def __repr__(self): def __repr__(self):
return f'Node(id={self.id}, name={self.name}, label={self.label}, operation={self.operation})' return f'Node(id={self.id}, name={self.name}, label={self.label}, operation={self.operation})'
...@@ -673,6 +685,37 @@ class Edge: ...@@ -673,6 +685,37 @@ class Edge:
} }
class Mutation:
    """
    An execution of mutation, which consists of four parts: a mutator, a list of decisions (choices),
    the model that it comes from, and the model that it becomes.

    In general cases, the mutation logs are not reliable and should not be replayed, as the mutators
    can be arbitrarily complex. However, for inline mutations, the labels correspond to mutator
    labels here, which makes the log useful for metadata visualization and python execution mode.

    Attributes
    ----------
    mutator
        Mutator.
    samples
        Decisions/choices.
    from_
        Model that it comes from.
    to
        Model that it becomes.
    """

    # Annotations quoted as forward references so the class does not depend on
    # import order of Mutator/Model.
    def __init__(self, mutator: 'Mutator', samples: List[Any], from_: 'Model', to: 'Model'):  # noqa: F821
        self.mutator: 'Mutator' = mutator  # noqa: F821
        self.samples: List[Any] = samples
        self.from_: 'Model' = from_
        self.to: 'Model' = to

    def __repr__(self):
        # Fixed: previously mis-reported the class name as "Edge".
        return f'Mutation(mutator={self.mutator}, samples={self.samples}, from={self.from_}, to={self.to})'
class IllegalGraphError(ValueError): class IllegalGraphError(ValueError):
def __init__(self, graph, *args): def __init__(self, graph, *args):
self._debug_dump_graph(graph) self._debug_dump_graph(graph)
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
# Licensed under the MIT license. # Licensed under the MIT license.
import logging import logging
import os
from typing import Any, Callable from typing import Any, Callable
from nni.runtime.msg_dispatcher_base import MsgDispatcherBase from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
...@@ -10,9 +9,6 @@ from nni.runtime.protocol import CommandType, send ...@@ -10,9 +9,6 @@ from nni.runtime.protocol import CommandType, send
from nni.utils import MetricType from nni.utils import MetricType
from .graph import MetricData from .graph import MetricData
from .execution.base import BaseExecutionEngine
from .execution.cgo_engine import CGOExecutionEngine
from .execution.api import set_execution_engine
from .integration_api import register_advisor from .integration_api import register_advisor
from .serializer import json_dumps, json_loads from .serializer import json_dumps, json_loads
...@@ -62,15 +58,6 @@ class RetiariiAdvisor(MsgDispatcherBase): ...@@ -62,15 +58,6 @@ class RetiariiAdvisor(MsgDispatcherBase):
self.parameters_count = 0 self.parameters_count = 0
engine = self._create_execution_engine()
set_execution_engine(engine)
def _create_execution_engine(self):
if os.environ.get('CGO') == 'true':
return CGOExecutionEngine()
else:
return BaseExecutionEngine()
def handle_initialize(self, data): def handle_initialize(self, data):
"""callback for initializing the advisor """callback for initializing the advisor
Parameters Parameters
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
from typing import (Any, Iterable, List, Optional) from typing import (Any, Iterable, List, Optional)
from .graph import Model from .graph import Model, Mutation, ModelStatus
__all__ = ['Sampler', 'Mutator'] __all__ = ['Sampler', 'Mutator']
...@@ -40,10 +40,13 @@ class Mutator: ...@@ -40,10 +40,13 @@ class Mutator:
and then use `Mutator.apply()` to mutate model. and then use `Mutator.apply()` to mutate model.
For certain mutator subclasses, strategy or sampler can use `Mutator.dry_run()` to predict choice candidates. For certain mutator subclasses, strategy or sampler can use `Mutator.dry_run()` to predict choice candidates.
# Method names are open for discussion. # Method names are open for discussion.
If mutator has a label, in most cases, it means that this mutator is applied to nodes with this label.
""" """
def __init__(self, sampler: Optional[Sampler] = None): def __init__(self, sampler: Optional[Sampler] = None, label: Optional[str] = None):
self.sampler: Optional[Sampler] = sampler self.sampler: Optional[Sampler] = sampler
self.label: Optional[str] = label
self._cur_model: Optional[Model] = None self._cur_model: Optional[Model] = None
self._cur_choice_idx: Optional[int] = None self._cur_choice_idx: Optional[int] = None
...@@ -64,9 +67,12 @@ class Mutator: ...@@ -64,9 +67,12 @@ class Mutator:
copy = model.fork() copy = model.fork()
self._cur_model = copy self._cur_model = copy
self._cur_choice_idx = 0 self._cur_choice_idx = 0
self._cur_samples = []
self.sampler.mutation_start(self, copy) self.sampler.mutation_start(self, copy)
self.mutate(copy) self.mutate(copy)
self.sampler.mutation_end(self, copy) self.sampler.mutation_end(self, copy)
copy.history.append(Mutation(self, self._cur_samples, model, copy))
copy.status = ModelStatus.Frozen
self._cur_model = None self._cur_model = None
self._cur_choice_idx = None self._cur_choice_idx = None
return copy return copy
...@@ -97,6 +103,7 @@ class Mutator: ...@@ -97,6 +103,7 @@ class Mutator:
""" """
assert self.sampler is not None and self._cur_model is not None and self._cur_choice_idx is not None assert self.sampler is not None and self._cur_model is not None and self._cur_choice_idx is not None
ret = self.sampler.choice(list(candidates), self, self._cur_model, self._cur_choice_idx) ret = self.sampler.choice(list(candidates), self, self._cur_model, self._cur_choice_idx)
self._cur_samples.append(ret)
self._cur_choice_idx += 1 self._cur_choice_idx += 1
return ret return ret
......
...@@ -4,18 +4,32 @@ ...@@ -4,18 +4,32 @@
import copy import copy
import warnings import warnings
from collections import OrderedDict from collections import OrderedDict
from typing import Any, List, Union, Dict from typing import Any, List, Union, Dict, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from ...serializer import Translatable, basic_unit from ...serializer import Translatable, basic_unit
from ...utils import uid from ...utils import uid, get_current_context
__all__ = ['LayerChoice', 'InputChoice', 'ValueChoice', 'Placeholder', 'ChosenInputs'] __all__ = ['LayerChoice', 'InputChoice', 'ValueChoice', 'Placeholder', 'ChosenInputs']
def _generate_new_label(label: Optional[str]):
if label is None:
return '_mutation_' + str(uid('mutation'))
return label
def _get_fixed_value(label: str):
    """
    Look up the frozen decision recorded for ``label`` in the current ``fixed`` context.

    Raises
    ------
    KeyError
        If a fixed context is active but has no entry for the label.

    NOTE(review): callers wrap this in ``except AssertionError`` — presumably
    ``get_current_context`` asserts when no 'fixed' context is active; confirm.
    """
    ret = get_current_context('fixed')
    # Compute the key once: a None label mints a generated key, and the error
    # message should show the key actually looked up, not the raw label.
    key = _generate_new_label(label)
    try:
        return ret[key]
    except KeyError as err:
        # Chain the cause so the original lookup failure is preserved.
        raise KeyError(f'Fixed context with {key} not found. Existing values are: {ret}') from err
class LayerChoice(nn.Module): class LayerChoice(nn.Module):
""" """
Layer choice selects one of the ``candidates``, then apply it on inputs and return results. Layer choice selects one of the ``candidates``, then apply it on inputs and return results.
...@@ -55,6 +69,16 @@ class LayerChoice(nn.Module): ...@@ -55,6 +69,16 @@ class LayerChoice(nn.Module):
``self.op_choice[1] = nn.Conv3d(...)``. Adding more choices is not supported yet. ``self.op_choice[1] = nn.Conv3d(...)``. Adding more choices is not supported yet.
""" """
def __new__(cls, candidates: Union[Dict[str, nn.Module], List[nn.Module]], label: str = None, **kwargs):
    """
    In a frozen ('fixed') context, short-circuit construction and return the chosen
    candidate module directly instead of a ``LayerChoice`` instance.
    """
    try:
        chosen = _get_fixed_value(label)
        # List candidates are indexed by position; dict candidates by key.
        if isinstance(candidates, list):
            return candidates[int(chosen)]
        else:
            return candidates[chosen]
    except AssertionError:
        # No fixed context active (presumably asserted inside _get_fixed_value's
        # context lookup — confirm) -> build a normal LayerChoice.
        return super().__new__(cls)
def __init__(self, candidates: Union[Dict[str, nn.Module], List[nn.Module]], label: str = None, **kwargs): def __init__(self, candidates: Union[Dict[str, nn.Module], List[nn.Module]], label: str = None, **kwargs):
super(LayerChoice, self).__init__() super(LayerChoice, self).__init__()
if 'key' in kwargs: if 'key' in kwargs:
...@@ -65,7 +89,7 @@ class LayerChoice(nn.Module): ...@@ -65,7 +89,7 @@ class LayerChoice(nn.Module):
if 'reduction' in kwargs: if 'reduction' in kwargs:
warnings.warn(f'"reduction" is deprecated. Ignoring...') warnings.warn(f'"reduction" is deprecated. Ignoring...')
self.candidates = candidates self.candidates = candidates
self._label = label if label is not None else f'layerchoice_{uid()}' self._label = _generate_new_label(label)
self.names = [] self.names = []
if isinstance(candidates, OrderedDict): if isinstance(candidates, OrderedDict):
...@@ -163,6 +187,12 @@ class InputChoice(nn.Module): ...@@ -163,6 +187,12 @@ class InputChoice(nn.Module):
Identifier of the input choice. Identifier of the input choice.
""" """
def __new__(cls, n_candidates: int, n_chosen: int = 1, reduction: str = 'sum', label: str = None, **kwargs):
    # In a frozen ('fixed') context, replace the InputChoice with a ChosenInputs
    # carrying the recorded selection; otherwise construct a normal InputChoice.
    try:
        return ChosenInputs(_get_fixed_value(label), reduction=reduction)
    except AssertionError:
        # No fixed context active -> regular instance.
        return super().__new__(cls)
def __init__(self, n_candidates: int, n_chosen: int = 1, reduction: str = 'sum', label: str = None, **kwargs): def __init__(self, n_candidates: int, n_chosen: int = 1, reduction: str = 'sum', label: str = None, **kwargs):
super(InputChoice, self).__init__() super(InputChoice, self).__init__()
if 'key' in kwargs: if 'key' in kwargs:
...@@ -176,7 +206,7 @@ class InputChoice(nn.Module): ...@@ -176,7 +206,7 @@ class InputChoice(nn.Module):
self.n_chosen = n_chosen self.n_chosen = n_chosen
self.reduction = reduction self.reduction = reduction
assert self.reduction in ['mean', 'concat', 'sum', 'none'] assert self.reduction in ['mean', 'concat', 'sum', 'none']
self._label = label if label is not None else f'inputchoice_{uid()}' self._label = _generate_new_label(label)
@property @property
def key(self): def key(self):
...@@ -265,10 +295,16 @@ class ValueChoice(Translatable, nn.Module): ...@@ -265,10 +295,16 @@ class ValueChoice(Translatable, nn.Module):
Identifier of the value choice. Identifier of the value choice.
""" """
def __new__(cls, candidates: List[Any], label: str = None):
    # In a frozen ('fixed') context, return the frozen value itself
    # rather than a ValueChoice module.
    try:
        return _get_fixed_value(label)
    except AssertionError:
        # No fixed context active -> regular instance.
        return super().__new__(cls)
def __init__(self, candidates: List[Any], label: str = None): def __init__(self, candidates: List[Any], label: str = None):
super().__init__() super().__init__()
self.candidates = candidates self.candidates = candidates
self._label = label if label is not None else f'valuechoice_{uid()}' self._label = _generate_new_label(label)
self._accessor = [] self._accessor = []
@property @property
...@@ -297,6 +333,14 @@ class ValueChoice(Translatable, nn.Module): ...@@ -297,6 +333,14 @@ class ValueChoice(Translatable, nn.Module):
raise KeyError(''.join([f'[{a}]' for a in self._accessor]) + f' does not work on {value}') raise KeyError(''.join([f'[{a}]' for a in self._accessor]) + f' does not work on {value}')
return v return v
def __copy__(self):
    # Shallow copy returns self — presumably so copies share the same choice
    # identity (label/accessor); TODO confirm this is intentional.
    return self
def __deepcopy__(self, memo):
    # memo is ignored; a fresh ValueChoice is built so the copy gets its own
    # _accessor list while reusing candidates and label.
    # NOTE(review): ValueChoice(...) re-invokes __new__, which in an active
    # 'fixed' context would return the frozen value instead — confirm desired.
    new_item = ValueChoice(self.candidates, self.label)
    new_item._accessor = [*self._accessor]
    return new_item
def __getitem__(self, item): def __getitem__(self, item):
""" """
Get a sub-element of value choice. Get a sub-element of value choice.
...@@ -331,9 +375,9 @@ class ChosenInputs(nn.Module): ...@@ -331,9 +375,9 @@ class ChosenInputs(nn.Module):
The already-chosen version of InputChoice. The already-chosen version of InputChoice.
""" """
def __init__(self, chosen: List[int], reduction: str): def __init__(self, chosen: Union[List[int], int], reduction: str):
super().__init__() super().__init__()
self.chosen = chosen self.chosen = chosen if isinstance(chosen, list) else [chosen]
self.reduction = reduction self.reduction = reduction
def forward(self, candidate_inputs): def forward(self, candidate_inputs):
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
import inspect
from typing import Any, List, Optional, Tuple from typing import Any, List, Optional, Tuple
import torch.nn as nn
from ...mutator import Mutator from ...mutator import Mutator
from ...graph import Cell, Model, Node from ...graph import Cell, Graph, Model, ModelStatus, Node
from .api import ValueChoice from ...utils import uid
from .api import LayerChoice, InputChoice, ValueChoice, Placeholder
class LayerChoiceMutator(Mutator): class LayerChoiceMutator(Mutator):
...@@ -40,7 +44,7 @@ class InputChoiceMutator(Mutator): ...@@ -40,7 +44,7 @@ class InputChoiceMutator(Mutator):
def mutate(self, model): def mutate(self, model):
n_candidates = self.nodes[0].operation.parameters['n_candidates'] n_candidates = self.nodes[0].operation.parameters['n_candidates']
n_chosen = self.nodes[0].operation.parameters['n_chosen'] n_chosen = self.nodes[0].operation.parameters['n_chosen']
candidates = list(range(n_candidates)) candidates = list(range(n_candidates))
chosen = [self.choice(candidates) for _ in range(n_chosen)] chosen = [self.choice(candidates) for _ in range(n_chosen)]
for node in self.nodes: for node in self.nodes:
...@@ -116,12 +120,96 @@ def process_inline_mutation(model: Model) -> Optional[List[Mutator]]: ...@@ -116,12 +120,96 @@ def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
mutator = LayerChoiceMutator(node_list) mutator = LayerChoiceMutator(node_list)
applied_mutators.append(mutator) applied_mutators.append(mutator)
if applied_mutators: if applied_mutators:
return applied_mutators return applied_mutators
return None return None
# The following are written for pure-python mode
class ManyChooseManyMutator(Mutator):
    """
    Choose based on labels. Will not affect the model itself.

    Decisions are only recorded in the mutation history; the pure-python
    engine applies them to the instantiated model elsewhere.
    """

    def __init__(self, label: Optional[str]):
        super().__init__(label=label)

    @staticmethod
    def candidates(node):
        # A node either stores an explicit candidate list, or just a count
        # ``n_candidates`` meaning the candidates are ``0 .. n_candidates - 1``.
        params = node.operation.parameters
        if 'n_candidates' in params:
            return list(range(params['n_candidates']))
        return params['candidates']

    @staticmethod
    def number_of_chosen(node):
        # Fall back to a single choice when the node does not carry ``n_chosen``.
        params = node.operation.parameters
        if 'n_chosen' in params:
            return params['n_chosen']
        return 1

    def mutate(self, model: Model):
        # This mutate has no effect on the graph; it only records choices.
        # Sample for the first labeled node only — every node sharing this
        # label shares the same decision, so one record suffices.
        for node in model.get_nodes_by_label(self.label):
            for _ in range(self.number_of_chosen(node)):
                self.choice(self.candidates(node))
            break
def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Optional[List[Mutator]]]:
    """
    Build a lightweight graph IR from a live PyTorch module for the pure-python
    execution engine, and collect the mutators implied by its mutation APIs.

    The returned ``Model`` keeps a reference to the original python class and its
    captured init parameters so the model can be re-instantiated in the trial
    process.  Each ``LayerChoice`` / ``InputChoice`` / ``ValueChoice`` submodule
    becomes a placeholder node in an otherwise empty graph; nodes are grouped by
    (label, type) and one ``ManyChooseManyMutator`` is created per group.

    Returns ``(model, None)`` when the module contains no mutable submodules.

    Raises
    ------
    ValueError
        If the model's ``__init__`` takes parameters but they were not captured
        (i.e. the class was not decorated with a serializing wrapper).
    NotImplementedError
        If the module contains a ``Placeholder``.
    """
    model = Model(_internal=True)
    graph = Graph(model, uid(), '_model', _internal=True)._register()
    model.python_class = pytorch_model.__class__
    # More than one parameter in __init__ means arguments besides ``self``:
    # we then need the recorded init parameters to rebuild the model later.
    if len(inspect.signature(model.python_class.__init__).parameters) > 1:
        if not hasattr(pytorch_model, '_init_parameters'):
            raise ValueError('Please annotate the model with @serialize decorator in python execution mode '
                             'if your model has init parameters.')
        model.python_init_params = pytorch_model._init_parameters
    else:
        model.python_init_params = {}
    for name, module in pytorch_model.named_modules():
        # tricky case: value choice that serves as parameters are stored in _init_parameters
        if hasattr(module, '_init_parameters'):
            for key, value in module._init_parameters.items():
                if isinstance(value, ValueChoice):
                    node = graph.add_node(name + '.init.' + key, 'ValueChoice', {'candidates': value.candidates})
                    node.label = value.label
        if isinstance(module, (LayerChoice, InputChoice, ValueChoice)):
            # TODO: check the label of module and warn if it's auto-generated
            pass
        if isinstance(module, LayerChoice):
            node = graph.add_node(name, 'LayerChoice', {'candidates': module.names})
            node.label = module.label
        if isinstance(module, InputChoice):
            node = graph.add_node(name, 'InputChoice',
                                  {'n_candidates': module.n_candidates, 'n_chosen': module.n_chosen})
            node.label = module.label
        if isinstance(module, ValueChoice):
            node = graph.add_node(name, 'ValueChoice', {'candidates': module.candidates})
            node.label = module.label
        if isinstance(module, Placeholder):
            raise NotImplementedError('Placeholder is not supported in python execution mode.')
    # The graph itself is final; only the recorded choice nodes remain mutable.
    model.status = ModelStatus.Frozen
    if not graph.hidden_nodes:
        return model, None
    mutators = []
    for nodes in _group_by_label_and_type(graph.hidden_nodes):
        # Nodes sharing a label must be interchangeable: same type, same parameters.
        assert _is_all_equal(map(lambda n: n.operation.type, nodes)), \
            f'Node with label "{nodes[0].label}" does not all have the same type.'
        assert _is_all_equal(map(lambda n: n.operation.parameters, nodes)), \
            f'Node with label "{nodes[0].label}" does not agree on parameters.'
        mutators.append(ManyChooseManyMutator(nodes[0].label))
    return model, mutators
# utility functions
def _is_all_equal(lst): def _is_all_equal(lst):
last = None last = None
for x in lst: for x in lst:
...@@ -131,6 +219,16 @@ def _is_all_equal(lst): ...@@ -131,6 +219,16 @@ def _is_all_equal(lst):
return True return True
def _group_by_label_and_type(nodes: List[Node]) -> List[List[Node]]:
    """Bucket nodes by ``(label, operation type)``, keeping first-seen bucket order."""
    buckets = {}
    for node in nodes:
        buckets.setdefault((node.label, node.operation.type), []).append(node)
    return list(buckets.values())
def _group_by_label(nodes: List[Node]) -> List[List[Node]]: def _group_by_label(nodes: List[Node]) -> List[List[Node]]:
result = {} result = {}
for node in nodes: for node in nodes:
......
...@@ -9,7 +9,7 @@ from typing import Any ...@@ -9,7 +9,7 @@ from typing import Any
import json_tricks import json_tricks
from .utils import get_importable_name, get_module_name, import_ from .utils import get_importable_name, get_module_name, import_, reset_uid
def get_init_parameters_or_fail(obj, silently=False): def get_init_parameters_or_fail(obj, silently=False):
...@@ -83,9 +83,11 @@ class Translatable(abc.ABC): ...@@ -83,9 +83,11 @@ class Translatable(abc.ABC):
pass pass
def _create_wrapper_cls(cls, store_init_parameters=True): def _create_wrapper_cls(cls, store_init_parameters=True, reset_mutation_uid=False):
class wrapper(cls): class wrapper(cls):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
if reset_mutation_uid:
reset_uid('mutation')
if store_init_parameters: if store_init_parameters:
argname_list = list(inspect.signature(cls.__init__).parameters.keys())[1:] argname_list = list(inspect.signature(cls.__init__).parameters.keys())[1:]
full_args = {} full_args = {}
...@@ -149,3 +151,15 @@ def basic_unit(cls): ...@@ -149,3 +151,15 @@ def basic_unit(cls):
import torch.nn as nn import torch.nn as nn
assert issubclass(cls, nn.Module), 'When using @basic_unit, the class must be a subclass of nn.Module.' assert issubclass(cls, nn.Module), 'When using @basic_unit, the class must be a subclass of nn.Module.'
return serialize_cls(cls) return serialize_cls(cls)
def model_wrapper(cls):
    """
    Wrap the model if you are using pure-python execution engine.

    The wrapper serves two purposes:

    1. Capture the init parameters of python class so that it can be re-instantiated in another process.
    2. Reset uid in `mutation` namespace so that each model counts from zero. Can be useful in unittest and other multi-model scenarios.

    Use as a class decorator on the top-level ``nn.Module`` of the model space.
    """
    return _create_wrapper_cls(cls, reset_mutation_uid=True)
...@@ -6,13 +6,20 @@ Entrypoint for trials. ...@@ -6,13 +6,20 @@ Entrypoint for trials.
Assuming execution engine is BaseExecutionEngine. Assuming execution engine is BaseExecutionEngine.
""" """
import os import argparse
from .execution.base import BaseExecutionEngine
from .execution.cgo_engine import CGOExecutionEngine
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Which execution engine this trial runs under; supplied on the trial command line.
    parser.add_argument('exec', choices=['base', 'py', 'cgo'])
    args = parser.parse_args()
    # Engines are imported lazily so a trial only loads the one it actually uses.
    if args.exec == 'base':
        from .execution.base import BaseExecutionEngine
        engine = BaseExecutionEngine
    elif args.exec == 'cgo':
        from .execution.cgo_engine import CGOExecutionEngine
        engine = CGOExecutionEngine
    elif args.exec == 'py':
        from .execution.python import PurePythonExecutionEngine
        engine = PurePythonExecutionEngine
    # argparse's ``choices`` guarantees one of the branches above assigned ``engine``.
    engine.trial_execute_graph()
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import inspect import inspect
import warnings import warnings
from collections import defaultdict from collections import defaultdict
from typing import Any from typing import Any, List, Dict
from pathlib import Path from pathlib import Path
...@@ -31,6 +31,10 @@ def uid(namespace: str = 'default') -> int: ...@@ -31,6 +31,10 @@ def uid(namespace: str = 'default') -> int:
return _last_uid[namespace] return _last_uid[namespace]
def reset_uid(namespace: str = 'default') -> None:
    """Reset the auto-increment uid counter of ``namespace`` back to zero."""
    _last_uid[namespace] = 0
def get_module_name(cls_or_func): def get_module_name(cls_or_func):
module_name = cls_or_func.__module__ module_name = cls_or_func.__module__
if module_name == '__main__': if module_name == '__main__':
...@@ -61,3 +65,42 @@ def get_module_name(cls_or_func): ...@@ -61,3 +65,42 @@ def get_module_name(cls_or_func):
def get_importable_name(cls, relocate_module=False): def get_importable_name(cls, relocate_module=False):
module_name = get_module_name(cls) if relocate_module else cls.__module__ module_name = get_module_name(cls) if relocate_module else cls.__module__
return module_name + '.' + cls.__name__ return module_name + '.' + cls.__name__
class ContextStack:
    """
    A globally accessible stack of context values, keyed by namespace.

    Initiate with ``with ContextStack(namespace, value):`` and read the
    innermost value of a namespace via ``get_current_context(namespace)``
    (which delegates to :meth:`top`).

    The stack is shared process-wide through a class-level attribute, so
    nested ``with`` blocks on the same namespace shadow each other in
    LIFO order.
    """

    # namespace -> stack of values; class-level, hence global to the process.
    _stack: Dict[str, List[Any]] = defaultdict(list)

    def __init__(self, key: str, value: Any):
        self.key = key
        self.value = value

    def __enter__(self):
        # Entering the ``with`` block makes ``value`` the innermost context.
        self.push(self.key, self.value)
        return self

    def __exit__(self, *args, **kwargs):
        self.pop(self.key)

    @classmethod
    def push(cls, key: str, value: Any):
        cls._stack[key].append(value)

    @classmethod
    def pop(cls, key: str) -> None:
        cls._stack[key].pop()

    @classmethod
    def top(cls, key: str) -> Any:
        frames = cls._stack[key]
        assert frames, 'Context is empty.'
        return frames[-1]
def get_current_context(key: str) -> Any:
    """
    Return the innermost value pushed onto namespace ``key`` of the global
    ``ContextStack``.

    Raises ``AssertionError`` (via ``ContextStack.top``) when the namespace
    has no active context.
    """
    return ContextStack.top(key)
...@@ -3,22 +3,24 @@ import torch.nn as nn ...@@ -3,22 +3,24 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
import torch
class _model(nn.Module): class _model(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.stem = stem() self.stem = stem()
self.flatten = torch.nn.Flatten()
self.fc1 = nn.Linear(1024, 256) self.fc1 = torch.nn.Linear(out_features=256, in_features=1024)
self.fc2 = nn.Linear(256, 10) self.fc2 = torch.nn.Linear(out_features=10, in_features=256)
self.softmax = torch.nn.Softmax()
def forward(self, image): def forward(self, image):
stem = self.stem(image) stem = self.stem(image)
flatten = stem.view(stem.size(0), -1) flatten = self.flatten(stem)
fc1 = self.fc1(flatten) fc1 = self.fc1(flatten)
fc2 = self.fc2(fc1) fc2 = self.fc2(fc1)
softmax = F.softmax(fc2, -1) softmax = self.softmax(fc2)
return softmax return softmax
...@@ -26,10 +28,10 @@ class _model(nn.Module): ...@@ -26,10 +28,10 @@ class _model(nn.Module):
class stem(nn.Module): class stem(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.conv1 = nn.Conv2d(out_channels=32, in_channels=1, kernel_size=5) self.conv1 = torch.nn.Conv2d(out_channels=32, in_channels=1, kernel_size=5)
self.pool1 = nn.MaxPool2d(kernel_size=2) self.pool1 = torch.nn.MaxPool2d(kernel_size=2)
self.conv2 = nn.Conv2d(out_channels=64, in_channels=32, kernel_size=5) self.conv2 = torch.nn.Conv2d(out_channels=64, in_channels=32, kernel_size=5)
self.pool2 = nn.MaxPool2d(kernel_size=2) self.pool2 = torch.nn.MaxPool2d(kernel_size=2)
def forward(self, *_inputs): def forward(self, *_inputs):
conv1 = self.conv1(_inputs[0]) conv1 = self.conv1(_inputs[0])
......
...@@ -5,10 +5,10 @@ ...@@ -5,10 +5,10 @@
"nodes": { "nodes": {
"stem": {"operation": {"type": "_cell", "cell_name": "stem"}}, "stem": {"operation": {"type": "_cell", "cell_name": "stem"}},
"flatten": {"operation": {"type": "Flatten"}}, "flatten": {"operation": {"type": "__torch__.torch.nn.Flatten"}},
"fc1": {"operation": {"type": "Dense", "parameters": {"out_features": 256, "in_features": 1024}}}, "fc1": {"operation": {"type": "__torch__.torch.nn.Linear", "parameters": {"out_features": 256, "in_features": 1024}}},
"fc2": {"operation": {"type": "Dense", "parameters": {"out_features": 10, "in_features": 256}}}, "fc2": {"operation": {"type": "__torch__.torch.nn.Linear", "parameters": {"out_features": 10, "in_features": 256}}},
"softmax": {"operation": {"type": "Softmax"}} "softmax": {"operation": {"type": "__torch__.torch.nn.Softmax"}}
}, },
"edges": [ "edges": [
...@@ -23,10 +23,10 @@ ...@@ -23,10 +23,10 @@
"stem": { "stem": {
"nodes": { "nodes": {
"conv1": {"operation": {"type": "__torch__.Conv2d", "parameters": {"out_channels": 32, "in_channels": 1, "kernel_size": 5}}}, "conv1": {"operation": {"type": "__torch__.torch.nn.Conv2d", "parameters": {"out_channels": 32, "in_channels": 1, "kernel_size": 5}}},
"pool1": {"operation": {"type": "__torch__.MaxPool2d", "parameters": {"kernel_size": 2}}}, "pool1": {"operation": {"type": "__torch__.torch.nn.MaxPool2d", "parameters": {"kernel_size": 2}}},
"conv2": {"operation": {"type": "__torch__.Conv2d", "parameters": {"out_channels": 64, "in_channels": 32, "kernel_size": 5}}}, "conv2": {"operation": {"type": "__torch__.torch.nn.Conv2d", "parameters": {"out_channels": 64, "in_channels": 32, "kernel_size": 5}}},
"pool2": {"operation": {"type": "__torch__.MaxPool2d", "parameters": {"kernel_size": 2}}} "pool2": {"operation": {"type": "__torch__.torch.nn.MaxPool2d", "parameters": {"kernel_size": 2}}}
}, },
"edges": [ "edges": [
...@@ -36,26 +36,5 @@ ...@@ -36,26 +36,5 @@
{"head": ["conv2", null], "tail": ["pool2", null]}, {"head": ["conv2", null], "tail": ["pool2", null]},
{"head": ["pool2", null], "tail": ["_outputs", 0]} {"head": ["pool2", null], "tail": ["_outputs", 0]}
] ]
},
"_evaluator": {
"module": "nni.retiarii.trainer.PyTorchImageClassificationTrainer",
"kwargs": {
"dataset_cls": "MNIST",
"dataset_kwargs": {
"root": "data/mnist",
"download": true
},
"dataloader_kwargs": {
"batch_size": 32
},
"optimizer_cls" : "SGD",
"optimizer_kwargs": {
"lr": 1e-3
},
"trainer_kwargs": {
"max_epochs": 1
}
}
} }
} }
import json import json
import os import os
import sys
import threading
import unittest import unittest
from pathlib import Path from pathlib import Path
import nni import nni.retiarii
from nni.retiarii import Model, submit_models from nni.retiarii import Model, submit_models
from nni.retiarii.codegen import model_to_pytorch_script from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.integration import RetiariiAdvisor, register_advisor from nni.retiarii.execution import set_execution_engine
from nni.retiarii.evaluator.pytorch import PyTorchImageClassificationTrainer from nni.retiarii.execution.base import BaseExecutionEngine
from nni.retiarii.utils import import_ from nni.retiarii.execution.python import PurePythonExecutionEngine
from nni.retiarii.integration import RetiariiAdvisor
@unittest.skip('Skipped in this version') class EngineTest(unittest.TestCase):
class CodeGenTest(unittest.TestCase): def test_codegen(self):
def test_mnist_example_pytorch(self): with open(self.enclosing_dir / 'mnist_pytorch.json') as f:
with open('mnist_pytorch.json') as f:
model = Model._load(json.load(f)) model = Model._load(json.load(f))
script = model_to_pytorch_script(model) script = model_to_pytorch_script(model)
with open('debug_mnist_pytorch.py') as f: with open(self.enclosing_dir / 'debug_mnist_pytorch.py') as f:
reference_script = f.read() reference_script = f.read()
self.assertEqual(script.strip(), reference_script.strip()) self.assertEqual(script.strip(), reference_script.strip())
def test_base_execution_engine(self):
advisor = RetiariiAdvisor()
set_execution_engine(BaseExecutionEngine())
with open(self.enclosing_dir / 'mnist_pytorch.json') as f:
model = Model._load(json.load(f))
submit_models(model, model)
@unittest.skip('Skipped in this version') advisor.stopping = True
class TrainerTest(unittest.TestCase): advisor.default_worker.join()
def test_trainer(self): advisor.assessor_worker.join()
sys.path.insert(0, Path(__file__).parent.as_posix())
Model = import_('debug_mnist_pytorch._model')
trainer = PyTorchImageClassificationTrainer(
Model(),
dataset_kwargs={'root': (Path(__file__).parent / 'data' / 'mnist').as_posix(), 'download': True},
dataloader_kwargs={'batch_size': 32},
optimizer_kwargs={'lr': 1e-3},
trainer_kwargs={'max_epochs': 1}
)
trainer.fit()
@unittest.skip('Skipped in this version')
class EngineTest(unittest.TestCase):
def test_submit_models(self): def test_py_execution_engine(self):
os.makedirs('generated', exist_ok=True)
from nni.runtime import protocol
protocol._out_file = open(Path(__file__).parent / 'generated/debug_protocol_out_file.py', 'wb')
advisor = RetiariiAdvisor() advisor = RetiariiAdvisor()
with open('mnist_pytorch.json') as f: set_execution_engine(PurePythonExecutionEngine())
model = Model._load(json.load(f)) model = Model._load({
'_model': {
'inputs': None,
'outputs': None,
'nodes': {
'layerchoice_1': {
'operation': {'type': 'LayerChoice', 'parameters': {'candidates': ['0', '1']}}
}
},
'edges': []
}
})
model.python_class = object
submit_models(model, model) submit_models(model, model)
advisor.stopping = True advisor.stopping = True
advisor.default_worker.join() advisor.default_worker.join()
advisor.assessor_worker.join() advisor.assessor_worker.join()
def test_execution_engine(self): def setUp(self) -> None:
pass self.enclosing_dir = Path(__file__).parent
os.makedirs(self.enclosing_dir / 'generated', exist_ok=True)
from nni.runtime import protocol
protocol._out_file = open(self.enclosing_dir / 'generated/debug_protocol_out_file.py', 'wb')
def tearDown(self) -> None:
from nni.runtime import protocol
protocol._out_file.close()
nni.retiarii.execution.api._execution_engine = None
nni.retiarii.integration_api._advisor = None
...@@ -8,7 +8,10 @@ import torch.nn.functional as F ...@@ -8,7 +8,10 @@ import torch.nn.functional as F
from nni.retiarii import Sampler, basic_unit from nni.retiarii import Sampler, basic_unit
from nni.retiarii.converter import convert_to_graph from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.nn.pytorch.mutator import process_inline_mutation from nni.retiarii.execution.python import _unpack_if_only_one
from nni.retiarii.nn.pytorch.mutator import process_inline_mutation, extract_mutation_from_pt_module
from nni.retiarii.serializer import model_wrapper
from nni.retiarii.utils import ContextStack
class EnumerateSampler(Sampler): class EnumerateSampler(Sampler):
...@@ -44,7 +47,7 @@ class MutableConv(nn.Module): ...@@ -44,7 +47,7 @@ class MutableConv(nn.Module):
return self.conv2(x) return self.conv2(x)
class TestHighLevelAPI(unittest.TestCase): class GraphIR(unittest.TestCase):
def _convert_to_ir(self, model): def _convert_to_ir(self, model):
script_module = torch.jit.script(model) script_module = torch.jit.script(model)
...@@ -56,7 +59,19 @@ class TestHighLevelAPI(unittest.TestCase): ...@@ -56,7 +59,19 @@ class TestHighLevelAPI(unittest.TestCase):
exec(model_code + '\n\nconverted_model = _model()', exec_vars) exec(model_code + '\n\nconverted_model = _model()', exec_vars)
return exec_vars['converted_model'] return exec_vars['converted_model']
def _get_model_with_mutators(self, pytorch_model):
model = self._convert_to_ir(pytorch_model)
mutators = process_inline_mutation(model)
return model, mutators
def get_serializer(self):
def dummy(cls):
return cls
return dummy
def test_layer_choice(self): def test_layer_choice(self):
@self.get_serializer()
class Net(nn.Module): class Net(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
...@@ -68,8 +83,7 @@ class TestHighLevelAPI(unittest.TestCase): ...@@ -68,8 +83,7 @@ class TestHighLevelAPI(unittest.TestCase):
def forward(self, x): def forward(self, x):
return self.module(x) return self.module(x)
model = self._convert_to_ir(Net()) model, mutators = self._get_model_with_mutators(Net())
mutators = process_inline_mutation(model)
self.assertEqual(len(mutators), 1) self.assertEqual(len(mutators), 1)
mutator = mutators[0].bind_sampler(EnumerateSampler()) mutator = mutators[0].bind_sampler(EnumerateSampler())
model1 = mutator.apply(model) model1 = mutator.apply(model)
...@@ -80,6 +94,7 @@ class TestHighLevelAPI(unittest.TestCase): ...@@ -80,6 +94,7 @@ class TestHighLevelAPI(unittest.TestCase):
torch.Size([1, 5, 3, 3])) torch.Size([1, 5, 3, 3]))
def test_input_choice(self): def test_input_choice(self):
@self.get_serializer()
class Net(nn.Module): class Net(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
...@@ -92,8 +107,7 @@ class TestHighLevelAPI(unittest.TestCase): ...@@ -92,8 +107,7 @@ class TestHighLevelAPI(unittest.TestCase):
x2 = self.conv2(x) x2 = self.conv2(x)
return self.input([x1, x2]) return self.input([x1, x2])
model = self._convert_to_ir(Net()) model, mutators = self._get_model_with_mutators(Net())
mutators = process_inline_mutation(model)
self.assertEqual(len(mutators), 1) self.assertEqual(len(mutators), 1)
mutator = mutators[0].bind_sampler(EnumerateSampler()) mutator = mutators[0].bind_sampler(EnumerateSampler())
model1 = mutator.apply(model) model1 = mutator.apply(model)
...@@ -104,6 +118,7 @@ class TestHighLevelAPI(unittest.TestCase): ...@@ -104,6 +118,7 @@ class TestHighLevelAPI(unittest.TestCase):
torch.Size([1, 5, 3, 3])) torch.Size([1, 5, 3, 3]))
def test_chosen_inputs(self): def test_chosen_inputs(self):
@self.get_serializer()
class Net(nn.Module): class Net(nn.Module):
def __init__(self, reduction): def __init__(self, reduction):
super().__init__() super().__init__()
...@@ -117,8 +132,7 @@ class TestHighLevelAPI(unittest.TestCase): ...@@ -117,8 +132,7 @@ class TestHighLevelAPI(unittest.TestCase):
return self.input([x1, x2]) return self.input([x1, x2])
for reduction in ['none', 'sum', 'mean', 'concat']: for reduction in ['none', 'sum', 'mean', 'concat']:
model = self._convert_to_ir(Net(reduction)) model, mutators = self._get_model_with_mutators(Net(reduction))
mutators = process_inline_mutation(model)
self.assertEqual(len(mutators), 1) self.assertEqual(len(mutators), 1)
mutator = mutators[0].bind_sampler(EnumerateSampler()) mutator = mutators[0].bind_sampler(EnumerateSampler())
model = mutator.apply(model) model = mutator.apply(model)
...@@ -133,6 +147,7 @@ class TestHighLevelAPI(unittest.TestCase): ...@@ -133,6 +147,7 @@ class TestHighLevelAPI(unittest.TestCase):
self.assertEqual(result.size(), torch.Size([1, 3, 3, 3])) self.assertEqual(result.size(), torch.Size([1, 3, 3, 3]))
def test_value_choice(self): def test_value_choice(self):
@self.get_serializer()
class Net(nn.Module): class Net(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
...@@ -142,8 +157,7 @@ class TestHighLevelAPI(unittest.TestCase): ...@@ -142,8 +157,7 @@ class TestHighLevelAPI(unittest.TestCase):
def forward(self, x): def forward(self, x):
return self.conv(x, self.index()) return self.conv(x, self.index())
model = self._convert_to_ir(Net()) model, mutators = self._get_model_with_mutators(Net())
mutators = process_inline_mutation(model)
self.assertEqual(len(mutators), 1) self.assertEqual(len(mutators), 1)
mutator = mutators[0].bind_sampler(EnumerateSampler()) mutator = mutators[0].bind_sampler(EnumerateSampler())
model1 = mutator.apply(model) model1 = mutator.apply(model)
...@@ -154,6 +168,7 @@ class TestHighLevelAPI(unittest.TestCase): ...@@ -154,6 +168,7 @@ class TestHighLevelAPI(unittest.TestCase):
torch.Size([1, 5, 3, 3])) torch.Size([1, 5, 3, 3]))
def test_value_choice_as_parameter(self): def test_value_choice_as_parameter(self):
@self.get_serializer()
class Net(nn.Module): class Net(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
...@@ -162,8 +177,7 @@ class TestHighLevelAPI(unittest.TestCase): ...@@ -162,8 +177,7 @@ class TestHighLevelAPI(unittest.TestCase):
def forward(self, x): def forward(self, x):
return self.conv(x) return self.conv(x)
model = self._convert_to_ir(Net()) model, mutators = self._get_model_with_mutators(Net())
mutators = process_inline_mutation(model)
self.assertEqual(len(mutators), 1) self.assertEqual(len(mutators), 1)
mutator = mutators[0].bind_sampler(EnumerateSampler()) mutator = mutators[0].bind_sampler(EnumerateSampler())
model1 = mutator.apply(model) model1 = mutator.apply(model)
...@@ -174,6 +188,7 @@ class TestHighLevelAPI(unittest.TestCase): ...@@ -174,6 +188,7 @@ class TestHighLevelAPI(unittest.TestCase):
torch.Size([1, 5, 1, 1])) torch.Size([1, 5, 1, 1]))
def test_value_choice_as_parameter(self): def test_value_choice_as_parameter(self):
@self.get_serializer()
class Net(nn.Module): class Net(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
...@@ -182,8 +197,7 @@ class TestHighLevelAPI(unittest.TestCase): ...@@ -182,8 +197,7 @@ class TestHighLevelAPI(unittest.TestCase):
def forward(self, x): def forward(self, x):
return self.conv(x) return self.conv(x)
model = self._convert_to_ir(Net()) model, mutators = self._get_model_with_mutators(Net())
mutators = process_inline_mutation(model)
self.assertEqual(len(mutators), 1) self.assertEqual(len(mutators), 1)
mutator = mutators[0].bind_sampler(EnumerateSampler()) mutator = mutators[0].bind_sampler(EnumerateSampler())
model1 = mutator.apply(model) model1 = mutator.apply(model)
...@@ -194,6 +208,7 @@ class TestHighLevelAPI(unittest.TestCase): ...@@ -194,6 +208,7 @@ class TestHighLevelAPI(unittest.TestCase):
torch.Size([1, 5, 1, 1])) torch.Size([1, 5, 1, 1]))
def test_value_choice_as_parameter(self): def test_value_choice_as_parameter(self):
@self.get_serializer()
class Net(nn.Module): class Net(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
...@@ -202,8 +217,7 @@ class TestHighLevelAPI(unittest.TestCase): ...@@ -202,8 +217,7 @@ class TestHighLevelAPI(unittest.TestCase):
def forward(self, x): def forward(self, x):
return self.conv(x) return self.conv(x)
model = self._convert_to_ir(Net()) model, mutators = self._get_model_with_mutators(Net())
mutators = process_inline_mutation(model)
self.assertEqual(len(mutators), 2) self.assertEqual(len(mutators), 2)
mutators[0].bind_sampler(EnumerateSampler()) mutators[0].bind_sampler(EnumerateSampler())
mutators[1].bind_sampler(EnumerateSampler()) mutators[1].bind_sampler(EnumerateSampler())
...@@ -214,6 +228,7 @@ class TestHighLevelAPI(unittest.TestCase): ...@@ -214,6 +228,7 @@ class TestHighLevelAPI(unittest.TestCase):
torch.Size([1, 8, 1, 1])) torch.Size([1, 8, 1, 1]))
def test_value_choice_as_parameter_shared(self): def test_value_choice_as_parameter_shared(self):
@self.get_serializer()
class Net(nn.Module): class Net(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
...@@ -223,8 +238,7 @@ class TestHighLevelAPI(unittest.TestCase): ...@@ -223,8 +238,7 @@ class TestHighLevelAPI(unittest.TestCase):
def forward(self, x): def forward(self, x):
return self.conv1(x) + self.conv2(x) return self.conv1(x) + self.conv2(x)
model = self._convert_to_ir(Net()) model, mutators = self._get_model_with_mutators(Net())
mutators = process_inline_mutation(model)
self.assertEqual(len(mutators), 1) self.assertEqual(len(mutators), 1)
mutator = mutators[0].bind_sampler(EnumerateSampler()) mutator = mutators[0].bind_sampler(EnumerateSampler())
model1 = mutator.apply(model) model1 = mutator.apply(model)
...@@ -235,6 +249,7 @@ class TestHighLevelAPI(unittest.TestCase): ...@@ -235,6 +249,7 @@ class TestHighLevelAPI(unittest.TestCase):
torch.Size([1, 8, 5, 5])) torch.Size([1, 8, 5, 5]))
def test_value_choice_in_functional(self): def test_value_choice_in_functional(self):
@self.get_serializer()
class Net(nn.Module): class Net(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
...@@ -243,8 +258,7 @@ class TestHighLevelAPI(unittest.TestCase): ...@@ -243,8 +258,7 @@ class TestHighLevelAPI(unittest.TestCase):
def forward(self, x): def forward(self, x):
return F.dropout(x, self.dropout_rate()) return F.dropout(x, self.dropout_rate())
model = self._convert_to_ir(Net()) model, mutators = self._get_model_with_mutators(Net())
mutators = process_inline_mutation(model)
self.assertEqual(len(mutators), 1) self.assertEqual(len(mutators), 1)
mutator = mutators[0].bind_sampler(EnumerateSampler()) mutator = mutators[0].bind_sampler(EnumerateSampler())
model1 = mutator.apply(model) model1 = mutator.apply(model)
...@@ -254,6 +268,7 @@ class TestHighLevelAPI(unittest.TestCase): ...@@ -254,6 +268,7 @@ class TestHighLevelAPI(unittest.TestCase):
self.assertAlmostEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).abs().sum().item(), 0) self.assertAlmostEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).abs().sum().item(), 0)
def test_value_choice_in_layer_choice(self): def test_value_choice_in_layer_choice(self):
@self.get_serializer()
class Net(nn.Module): class Net(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
...@@ -265,8 +280,7 @@ class TestHighLevelAPI(unittest.TestCase): ...@@ -265,8 +280,7 @@ class TestHighLevelAPI(unittest.TestCase):
def forward(self, x): def forward(self, x):
return self.linear(x) return self.linear(x)
model = self._convert_to_ir(Net()) model, mutators = self._get_model_with_mutators(Net())
mutators = process_inline_mutation(model)
self.assertEqual(len(mutators), 3) self.assertEqual(len(mutators), 3)
sz_counter = Counter() sz_counter = Counter()
sampler = RandomSampler() sampler = RandomSampler()
...@@ -278,6 +292,7 @@ class TestHighLevelAPI(unittest.TestCase): ...@@ -278,6 +292,7 @@ class TestHighLevelAPI(unittest.TestCase):
self.assertEqual(len(sz_counter), 4) self.assertEqual(len(sz_counter), 4)
def test_shared(self): def test_shared(self):
@self.get_serializer()
class Net(nn.Module): class Net(nn.Module):
def __init__(self, shared=True): def __init__(self, shared=True):
super().__init__() super().__init__()
...@@ -294,16 +309,14 @@ class TestHighLevelAPI(unittest.TestCase): ...@@ -294,16 +309,14 @@ class TestHighLevelAPI(unittest.TestCase):
def forward(self, x): def forward(self, x):
return self.module1(x) + self.module2(x) return self.module1(x) + self.module2(x)
model = self._convert_to_ir(Net()) model, mutators = self._get_model_with_mutators(Net())
mutators = process_inline_mutation(model)
self.assertEqual(len(mutators), 1) self.assertEqual(len(mutators), 1)
sampler = RandomSampler() sampler = RandomSampler()
mutator = mutators[0].bind_sampler(sampler) mutator = mutators[0].bind_sampler(sampler)
self.assertEqual(self._get_converted_pytorch_model(mutator.apply(model))(torch.randn(1, 3, 3, 3)).size(0), 1) self.assertEqual(self._get_converted_pytorch_model(mutator.apply(model))(torch.randn(1, 3, 3, 3)).size(0), 1)
self.assertEqual(sampler.counter, 1) self.assertEqual(sampler.counter, 1)
model = self._convert_to_ir(Net(shared=False)) model, mutators = self._get_model_with_mutators(Net(shared=False))
mutators = process_inline_mutation(model)
self.assertEqual(len(mutators), 2) self.assertEqual(len(mutators), 2)
sampler = RandomSampler() sampler = RandomSampler()
# repeat test. Expectation: sometimes succeeds, sometimes fails. # repeat test. Expectation: sometimes succeeds, sometimes fails.
...@@ -321,6 +334,7 @@ class TestHighLevelAPI(unittest.TestCase): ...@@ -321,6 +334,7 @@ class TestHighLevelAPI(unittest.TestCase):
self.assertLess(failed_count, 30) self.assertLess(failed_count, 30)
def test_valuechoice_access(self): def test_valuechoice_access(self):
@self.get_serializer()
class Net(nn.Module): class Net(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
...@@ -330,8 +344,7 @@ class TestHighLevelAPI(unittest.TestCase): ...@@ -330,8 +344,7 @@ class TestHighLevelAPI(unittest.TestCase):
def forward(self, x): def forward(self, x):
return self.conv(x) return self.conv(x)
model = self._convert_to_ir(Net()) model, mutators = self._get_model_with_mutators(Net())
mutators = process_inline_mutation(model)
self.assertEqual(len(mutators), 1) self.assertEqual(len(mutators), 1)
mutators[0].bind_sampler(EnumerateSampler()) mutators[0].bind_sampler(EnumerateSampler())
input = torch.randn(1, 3, 5, 5) input = torch.randn(1, 3, 5, 5)
...@@ -340,6 +353,7 @@ class TestHighLevelAPI(unittest.TestCase): ...@@ -340,6 +353,7 @@ class TestHighLevelAPI(unittest.TestCase):
self.assertEqual(self._get_converted_pytorch_model(mutators[0].apply(model))(input).size(), self.assertEqual(self._get_converted_pytorch_model(mutators[0].apply(model))(input).size(),
torch.Size([1, 8, 1, 1])) torch.Size([1, 8, 1, 1]))
@self.get_serializer()
class Net2(nn.Module): class Net2(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
...@@ -354,14 +368,14 @@ class TestHighLevelAPI(unittest.TestCase): ...@@ -354,14 +368,14 @@ class TestHighLevelAPI(unittest.TestCase):
x = self.conv(x) x = self.conv(x)
return self.conv1(torch.cat((x, x), 1)) return self.conv1(torch.cat((x, x), 1))
model = self._convert_to_ir(Net2()) model, mutators = self._get_model_with_mutators(Net2())
mutators = process_inline_mutation(model)
self.assertEqual(len(mutators), 1) self.assertEqual(len(mutators), 1)
mutators[0].bind_sampler(EnumerateSampler()) mutators[0].bind_sampler(EnumerateSampler())
input = torch.randn(1, 3, 5, 5) input = torch.randn(1, 3, 5, 5)
self._get_converted_pytorch_model(mutators[0].apply(model))(input) self._get_converted_pytorch_model(mutators[0].apply(model))(input)
def test_valuechoice_access_functional(self): def test_valuechoice_access_functional(self):
@self.get_serializer()
class Net(nn.Module): class Net(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
...@@ -370,8 +384,7 @@ class TestHighLevelAPI(unittest.TestCase): ...@@ -370,8 +384,7 @@ class TestHighLevelAPI(unittest.TestCase):
def forward(self, x): def forward(self, x):
return F.dropout(x, self.dropout_rate()[0]) return F.dropout(x, self.dropout_rate()[0])
model = self._convert_to_ir(Net()) model, mutators = self._get_model_with_mutators(Net())
mutators = process_inline_mutation(model)
self.assertEqual(len(mutators), 1) self.assertEqual(len(mutators), 1)
mutator = mutators[0].bind_sampler(EnumerateSampler()) mutator = mutators[0].bind_sampler(EnumerateSampler())
model1 = mutator.apply(model) model1 = mutator.apply(model)
...@@ -381,6 +394,7 @@ class TestHighLevelAPI(unittest.TestCase): ...@@ -381,6 +394,7 @@ class TestHighLevelAPI(unittest.TestCase):
self.assertAlmostEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).abs().sum().item(), 0) self.assertAlmostEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).abs().sum().item(), 0)
def test_valuechoice_access_functional_expression(self): def test_valuechoice_access_functional_expression(self):
@self.get_serializer()
class Net(nn.Module): class Net(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
...@@ -391,8 +405,7 @@ class TestHighLevelAPI(unittest.TestCase): ...@@ -391,8 +405,7 @@ class TestHighLevelAPI(unittest.TestCase):
# ValueError: dropout probability has to be between 0 and 1, but got 1.05 # ValueError: dropout probability has to be between 0 and 1, but got 1.05
return F.dropout(x, self.dropout_rate()[0] - .1) return F.dropout(x, self.dropout_rate()[0] - .1)
model = self._convert_to_ir(Net()) model, mutators = self._get_model_with_mutators(Net())
mutators = process_inline_mutation(model)
self.assertEqual(len(mutators), 1) self.assertEqual(len(mutators), 1)
mutator = mutators[0].bind_sampler(EnumerateSampler()) mutator = mutators[0].bind_sampler(EnumerateSampler())
model1 = mutator.apply(model) model1 = mutator.apply(model)
...@@ -400,3 +413,29 @@ class TestHighLevelAPI(unittest.TestCase): ...@@ -400,3 +413,29 @@ class TestHighLevelAPI(unittest.TestCase):
self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)) self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3))
self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(), torch.Size([1, 3, 3, 3])) self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(), torch.Size([1, 3, 3, 3]))
self.assertAlmostEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).abs().sum().item(), 0) self.assertAlmostEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).abs().sum().item(), 0)
class Python(GraphIR):
def _get_converted_pytorch_model(self, model_ir):
mutation = {mut.mutator.label: _unpack_if_only_one(mut.samples) for mut in model_ir.history}
with ContextStack('fixed', mutation):
model = model_ir.python_class(**model_ir.python_init_params)
return model
def _get_model_with_mutators(self, pytorch_model):
return extract_mutation_from_pt_module(pytorch_model)
def get_serializer(self):
return model_wrapper
@unittest.skip
def test_value_choice(self): ...
@unittest.skip
def test_value_choice_in_functional(self): ...
@unittest.skip
def test_valuechoice_access_functional(self): ...
@unittest.skip
def test_valuechoice_access_functional_expression(self): ...
...@@ -60,7 +60,14 @@ def test_mutation(): ...@@ -60,7 +60,14 @@ def test_mutation():
model2 = mutator.apply(model1) model2 = mutator.apply(model1)
assert _get_pools(model2) == (global_pool, max_pool) assert _get_pools(model2) == (global_pool, max_pool)
assert model2.history == [model0, model1] assert len(model2.history) == 2
assert model2.history[0].from_ == model0
assert model2.history[0].to == model1
assert model2.history[1].from_ == model1
assert model2.history[1].to == model2
assert model2.history[0].mutator == mutator
assert model2.history[1].mutator == mutator
assert _get_pools(model0) == (max_pool, max_pool) assert _get_pools(model0) == (max_pool, max_pool)
assert _get_pools(model1) == (avg_pool, global_pool) assert _get_pools(model1) == (avg_pool, global_pool)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment