Unverified Commit fddc8adc authored by Yuge Zhang, committed by GitHub

[Retiarii] Grid search, random and evolution strategy (#3377)

parent bd2543e3
@@ -72,10 +72,16 @@ Oneshot Trainers
 Strategies
 ----------

-.. autoclass:: nni.retiarii.strategies.RandomStrategy
+.. autoclass:: nni.retiarii.strategy.Random
    :members:

-.. autoclass:: nni.retiarii.strategies.TPEStrategy
+.. autoclass:: nni.retiarii.strategy.GridSearch
+   :members:
+
+.. autoclass:: nni.retiarii.strategy.RegularizedEvolution
+   :members:
+
+.. autoclass:: nni.retiarii.strategy.TPEStrategy
    :members:

 Retiarii Experiments
......
@@ -167,13 +167,13 @@ In the following table, we listed the available trainers and strategies.
      - TPEStrategy
      - DartsTrainer
    * - Regression
-     - RandomStrategy
+     - Random
      - EnasTrainer
    * -
-     -
+     - GridSearch
      - ProxylessTrainer
    * -
-     -
+     - RegularizedEvolution
      - SinglePathTrainer (RandomTrainer)

 Their usage and API documentation can be found `here <./ApiReference>`__.
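A minimal sketch of how the renamed strategies are instantiated (constructor arguments are illustrative; all four classes appear in the examples and tests later in this commit):

.. code-block:: python

    import nni.retiarii.strategy as strategy

    simple_strategy = strategy.Random()                    # random search, dedup on by default
    # simple_strategy = strategy.GridSearch(shuffle=True)  # exhaustive traversal
    # simple_strategy = strategy.RegularizedEvolution()    # aging evolution
    # simple_strategy = strategy.TPEStrategy()             # TPE-based sampling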
@@ -204,7 +204,7 @@ After all the above are prepared, it is time to start an experiment to do the mo
 .. code-block:: python

-    exp = RetiariiExperiment(base_model, trainer, applied_mutators, simple_startegy)
+    exp = RetiariiExperiment(base_model, trainer, applied_mutators, simple_strategy)
     exp_config = RetiariiExeConfig('local')
     exp_config.experiment_name = 'mnasnet_search'
     exp_config.trial_concurrency = 2
......
@@ -3,10 +3,12 @@ Customize A New Strategy

 To write a new strategy, you should inherit the base strategy class ``BaseStrategy``, then implement the member function ``run``. This member function takes ``base_model`` and ``applied_mutators`` as its input arguments. It can simply apply the user-specified mutators in ``applied_mutators`` onto ``base_model`` to generate a new model. When a mutator is applied, it should be bound with a sampler (e.g., ``RandomSampler``). Every sampler implements the ``choice`` function, which chooses value(s) from candidate values. The ``choice`` functions invoked in mutators are executed with the sampler.

-Below is a very simple random strategy, the complete code can be found :githublink:`here <nni/retiarii/strategies/random_strategy.py>`.
+Below is a very simple random strategy, which makes all choices completely at random.

 .. code-block:: python

+    from nni.retiarii import Sampler
+
     class RandomSampler(Sampler):
         def choice(self, candidates, mutator, model, index):
             return random.choice(candidates)
@@ -31,6 +33,6 @@ Below is a very simple random strategy, the complete code can be found :githubli
             else:
                 time.sleep(2)

-You can find that this strategy does not know the search space beforehand, it passively makes decisions every time ``choice`` is invoked from mutators. If a strategy wants to know the whole search space before making any decision (e.g., TPE, SMAC), it can use ``dry_run`` function provided by ``Mutator`` to obtain the space. An example strategy can be found :githublink:`here <nni/retiarii/strategies/tpe_strategy.py>`.
+You can see that this strategy does not know the search space beforehand; it passively makes a decision every time ``choice`` is invoked from a mutator. If a strategy wants to know the whole search space before making any decision (e.g., TPE, SMAC), it can use the ``dry_run`` function provided by ``Mutator`` to obtain the space, as sketched below. An example strategy can be found :githublink:`here <nni/retiarii/strategy/tpe_strategy.py>`.

 After generating a new model, the strategy can use our provided APIs (e.g., ``submit_models``, ``is_stopped_exec``) to submit the model and get its reported results. More APIs can be found in `API References <./ApiReference.rst>`__.
\ No newline at end of file
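As a rough illustration of the ``dry_run`` approach mentioned above, here is a minimal sketch (it mirrors the ``dry_run_for_search_space`` helper added later in this commit; the function name is otherwise hypothetical):

.. code-block:: python

    def enumerate_search_space(model, mutators):
        # Each dry_run returns the candidate lists recorded for one mutator,
        # together with a copy of the model, keeping the original intact.
        search_space = {}
        for mutator in mutators:
            recorded_candidates, model = mutator.dry_run(model)
            for i, candidates in enumerate(recorded_candidates):
                search_space[(mutator, i)] = candidates
        return search_space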
@@ -65,11 +65,11 @@ class BaseExecutionEngine(AbstractExecutionEngine):
         if self.resources <= 0:
             _logger.warning('There is no available resource, but trial is submitted.')
         self.resources -= 1
-        _logger.info('on_resource_used: %d', self.resources)
+        _logger.info('Resource used. Remaining: %d', self.resources)

     def _request_trial_jobs_callback(self, num_trials: int) -> None:
         self.resources += num_trials
-        _logger.info('on_resource_available: %d', self.resources)
+        _logger.info('New resource available. Remaining: %d', self.resources)

     def _trial_end_callback(self, trial_id: int, success: bool) -> None:
         model = self._running_models[trial_id]
......
@@ -17,7 +17,7 @@ from ..graph import Model, TrainingConfig
 from ..integration import RetiariiAdvisor
 from ..mutator import Mutator
 from ..nn.pytorch.mutator import process_inline_mutation
-from ..strategies.strategy import BaseStrategy
+from ..strategy import BaseStrategy
 from ..trainer.interface import BaseOneShotTrainer, BaseTrainer
 from ..utils import get_records
......
@@ -131,7 +131,7 @@ class Model:
         new_model = Model(_internal=True)
         new_model._root_graph_name = self._root_graph_name
         new_model.graphs = {name: graph._fork_to(new_model) for name, graph in self.graphs.items()}
-        new_model.training_config = copy.deepcopy(self.training_config)
+        new_model.training_config = copy.deepcopy(self.training_config)  # TODO: this may be a problem when training config is large
         new_model.history = self.history + [self]
         return new_model
......
from .tpe_strategy import TPEStrategy
from .random_strategy import RandomStrategy
import logging
import random
import time

from .. import Sampler, submit_models, query_available_resources
from .strategy import BaseStrategy

_logger = logging.getLogger(__name__)


class RandomSampler(Sampler):
    def choice(self, candidates, mutator, model, index):
        return random.choice(candidates)


class RandomStrategy(BaseStrategy):
    def __init__(self):
        self.random_sampler = RandomSampler()

    def run(self, base_model, applied_mutators):
        _logger.info('stargety start...')
        while True:
            avail_resource = query_available_resources()
            if avail_resource > 0:
                model = base_model
                _logger.info('apply mutators...')
                _logger.info('mutators: %s', str(applied_mutators))
                for mutator in applied_mutators:
                    mutator.bind_sampler(self.random_sampler)
                    model = mutator.apply(model)
                # run models
                submit_models(model)
            else:
                time.sleep(2)
from .base import BaseStrategy
from .bruteforce import Random, GridSearch
from .evolution import RegularizedEvolution
from .tpe_strategy import TPEStrategy
import copy
import itertools
import logging
import random
import time
from typing import Any, Dict, List

from .. import Sampler, submit_models, query_available_resources
from .base import BaseStrategy
from .utils import dry_run_for_search_space, get_targeted_model

_logger = logging.getLogger(__name__)


def grid_generator(search_space: Dict[Any, List[Any]], shuffle=True):
    keys = list(search_space.keys())
    search_space_values = copy.deepcopy(list(search_space.values()))
    if shuffle:
        for values in search_space_values:
            random.shuffle(values)
    for values in itertools.product(*search_space_values):
        yield {key: value for key, value in zip(keys, values)}
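# Example (hypothetical space): grid_generator({'a': [1, 2], 'b': ['x', 'y']}, shuffle=False)
# yields {'a': 1, 'b': 'x'}, {'a': 1, 'b': 'y'}, {'a': 2, 'b': 'x'}, {'a': 2, 'b': 'y'}.
# With shuffle=True (the default) the candidate lists are shuffled first, so the full
# grid is still covered, just in a randomized order.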
def random_generator(search_space: Dict[Any, List[Any]], dedup=True, retries=500):
    keys = list(search_space.keys())
    history = set()
    search_space_values = copy.deepcopy(list(search_space.values()))
    while True:
        for retry_count in range(retries):
            selected = [random.choice(v) for v in search_space_values]
            if not dedup:
                break
            selected = tuple(selected)
            if selected not in history:
                history.add(selected)
                break
        if retry_count + 1 == retries:
            _logger.info('Random generation has run out of patience. There is nothing to search. Exiting.')
            return
        yield {key: value for key, value in zip(keys, selected)}
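# Example (hypothetical space): random_generator({'a': [1, 2], 'b': ['x', 'y']}) yields the
# four combinations in a random order without repetition (dedup=True), then gives up after
# `retries` consecutive duplicate draws once the space is exhausted.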
class GridSearch(BaseStrategy):
    """
    Traverse the search space and try all the possible combinations one by one.

    Parameters
    ----------
    shuffle : bool
        Shuffle the order in a candidate list, so that they are tried in a random order. Default: true.
    """

    def __init__(self, shuffle=True):
        self._polling_interval = 2.
        self.shuffle = shuffle

    def run(self, base_model, applied_mutators):
        search_space = dry_run_for_search_space(base_model, applied_mutators)
        for sample in grid_generator(search_space, shuffle=self.shuffle):
            _logger.info('New model created. Waiting for resource. %s', str(sample))
            if query_available_resources() <= 0:
                time.sleep(self._polling_interval)
            submit_models(get_targeted_model(base_model, applied_mutators, sample))


class _RandomSampler(Sampler):
    def choice(self, candidates, mutator, model, index):
        return random.choice(candidates)


class Random(BaseStrategy):
    """
    Random search on the search space.

    Parameters
    ----------
    variational : bool
        Do not dry run to get the full search space. Used when the search space has variational size or candidates. Default: false.
    dedup : bool
        Do not try the same configuration twice. When variational is true, deduplication is not supported. Default: true.
    """

    def __init__(self, variational=False, dedup=True):
        self.variational = variational
        self.dedup = dedup
        if variational and dedup:
            raise ValueError('Dedup is not supported in variational mode.')
        self.random_sampler = _RandomSampler()
        self._polling_interval = 2.

    def run(self, base_model, applied_mutators):
        if self.variational:
            _logger.info('Random search running in variational mode.')
            sampler = _RandomSampler()
            for mutator in applied_mutators:
                mutator.bind_sampler(sampler)
            while True:
                avail_resource = query_available_resources()
                if avail_resource > 0:
                    model = base_model
                    for mutator in applied_mutators:
                        model = mutator.apply(model)
                    _logger.info('New model created. Applied mutators are: %s', str(applied_mutators))
                    submit_models(model)
                else:
                    time.sleep(self._polling_interval)
        else:
            _logger.info('Random search running in fixed size mode. Dedup: %s.', 'on' if self.dedup else 'off')
            search_space = dry_run_for_search_space(base_model, applied_mutators)
            for sample in random_generator(search_space, dedup=self.dedup):
                _logger.info('New model created. Waiting for resource. %s', str(sample))
                if query_available_resources() <= 0:
                    time.sleep(self._polling_interval)
                submit_models(get_targeted_model(base_model, applied_mutators, sample))
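# A minimal usage sketch for the two modes of Random (values illustrative):
#
#     strategy_fixed = Random()                             # fixed-size mode: dry-runs the space, no repeats
#     strategy_var = Random(variational=True, dedup=False)  # variational mode requires dedup=False,
#                                                           # otherwise __init__ raises ValueError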
import collections
import dataclasses
import logging
import random
import time

from ..execution import query_available_resources, submit_models
from ..graph import ModelStatus
from .base import BaseStrategy
from .utils import dry_run_for_search_space, get_targeted_model

_logger = logging.getLogger(__name__)


@dataclasses.dataclass
class Individual:
    """
    A class that represents an individual.

    Holds two attributes, where ``x`` is the sampled configuration and ``y`` is the metric (e.g., accuracy).
    """
    x: dict
    y: float


class RegularizedEvolution(BaseStrategy):
    """
    Algorithm for regularized evolution (i.e. aging evolution).
    Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image Classifier Architecture Search".

    Parameters
    ----------
    optimize_mode : str
        Can be one of "maximize" and "minimize". Default: maximize.
    population_size : int
        The number of individuals to keep in the population. Default: 100.
    cycles : int
        The number of cycles (trials) the algorithm should run for. Default: 20000.
    sample_size : int
        The number of individuals that should participate in each tournament. Default: 25.
    mutation_prob : float
        Probability that mutation happens in each dimension. Default: 0.05.
    on_failure : str
        Can be one of "ignore" and "worst". If "ignore", simply give up the model and find a new one.
        If "worst", mark the model as -inf (if maximize, inf if minimize), so that the algorithm "learns" to avoid such models.
        Default: ignore.
    """

    def __init__(self, optimize_mode='maximize', population_size=100, sample_size=25, cycles=20000,
                 mutation_prob=0.05, on_failure='ignore'):
        assert optimize_mode in ['maximize', 'minimize']
        assert on_failure in ['ignore', 'worst']
        assert sample_size < population_size
        self.optimize_mode = optimize_mode
        self.population_size = population_size
        self.sample_size = sample_size
        self.cycles = cycles
        self.mutation_prob = mutation_prob
        self.on_failure = on_failure

        self._worst = float('-inf') if self.optimize_mode == 'maximize' else float('inf')

        self._success_count = 0
        self._population = collections.deque()
        self._running_models = []
        self._polling_interval = 2.

    def random(self, search_space):
        return {k: random.choice(v) for k, v in search_space.items()}

    def mutate(self, parent, search_space):
        child = {}
        for k, v in parent.items():
            if random.uniform(0, 1) < self.mutation_prob:
                # NOTE: we do not exclude the original choice here for simplicity,
                # which is slightly different from the original paper.
                child[k] = random.choice(search_space[k])
            else:
                child[k] = v
        return child
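    # NOTE (illustrative): in mutate() above, each key of the parent sample independently
    # has probability `mutation_prob` of being resampled uniformly from search_space[key];
    # all other keys are copied unchanged.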
    def best_parent(self):
        samples = [p for p in self._population]  # copy population
        random.shuffle(samples)
        samples = list(samples)[:self.sample_size]
        if self.optimize_mode == 'maximize':
            parent = max(samples, key=lambda sample: sample.y)
        else:
            parent = min(samples, key=lambda sample: sample.y)
        return parent.x
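    # NOTE: best_parent() implements the tournament selection step of aging evolution:
    # draw `sample_size` individuals at random from the population and return the
    # configuration (x) of the fittest one.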
    def run(self, base_model, applied_mutators):
        search_space = dry_run_for_search_space(base_model, applied_mutators)
        # Run the first population regardless of concurrency
        _logger.info('Initializing the first population.')
        while len(self._population) + len(self._running_models) <= self.population_size:
            # try to submit new models
            while len(self._population) + len(self._running_models) < self.population_size:
                config = self.random(search_space)
                self._submit_config(config, base_model, applied_mutators)
            # collect results
            self._move_succeeded_models_to_population()
            self._remove_failed_models_from_running_list()
            time.sleep(self._polling_interval)

            if len(self._population) >= self.population_size:
                break

        # Resource-aware mutation of models
        _logger.info('Running mutations.')
        while self._success_count + len(self._running_models) <= self.cycles:
            # try to submit new models
            while query_available_resources() > 0 and self._success_count + len(self._running_models) < self.cycles:
                config = self.mutate(self.best_parent(), search_space)
                self._submit_config(config, base_model, applied_mutators)
            # collect results
            self._move_succeeded_models_to_population()
            self._remove_failed_models_from_running_list()
            time.sleep(self._polling_interval)

            if self._success_count >= self.cycles:
                break

    def _submit_config(self, config, base_model, mutators):
        _logger.info('Model submitted to running queue: %s', config)
        model = get_targeted_model(base_model, mutators, config)
        submit_models(model)
        self._running_models.append((config, model))
        return model

    def _move_succeeded_models_to_population(self):
        completed_indices = []
        for i, (config, model) in enumerate(self._running_models):
            metric = None
            if self.on_failure == 'worst' and model.status == ModelStatus.Failed:
                metric = self._worst
            elif model.status == ModelStatus.Trained:
                metric = model.metric
            if metric is not None:
                individual = Individual(config, metric)
                _logger.info('Individual created: %s', str(individual))
                self._population.append(individual)
                if len(self._population) > self.population_size:
                    self._population.popleft()
                completed_indices.append(i)
        for i in completed_indices[::-1]:
            # delete from end to start so that the indices are not affected by the removal
            self._success_count += 1
            self._running_models.pop(i)

    def _remove_failed_models_from_running_list(self):
        # This is only done when on_failure policy is set to "ignore".
        # Otherwise, failed models will be treated as the worst metric when processed.
        if self.on_failure == 'ignore':
            number_of_failed_models = len([g for g in self._running_models if g[1].status == ModelStatus.Failed])
            self._running_models = [g for g in self._running_models if g[1].status != ModelStatus.Failed]
            if number_of_failed_models > 0:
                _logger.info('%d failed models are ignored. Will retry.', number_of_failed_models)
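# A small, known-good configuration for quick runs (taken from the test at the end
# of this commit):
#
#     evolution = RegularizedEvolution(population_size=5, sample_size=3, cycles=10,
#                                      mutation_prob=0.5, on_failure='ignore')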
@@ -4,7 +4,7 @@ import time
 from nni.algorithms.hpo.hyperopt_tuner import HyperoptTuner

 from .. import Sampler, submit_models, query_available_resources, is_stopped_exec
-from .strategy import BaseStrategy
+from .base import BaseStrategy

 _logger = logging.getLogger(__name__)
@@ -50,16 +50,14 @@ class TPEStrategy(BaseStrategy):
                 sample_space.extend(recorded_candidates)
         self.tpe_sampler.update_sample_space(sample_space)

-        _logger.info('stargety start...')
+        _logger.info('TPE strategy has been started.')
         while True:
             avail_resource = query_available_resources()
             if avail_resource > 0:
                 model = base_model
-                _logger.info('apply mutators...')
-                _logger.info('mutators: %s', str(applied_mutators))
+                _logger.info('New model created. Applied mutators: %s', str(applied_mutators))
                 self.tpe_sampler.generate_samples(self.model_id)
                 for mutator in applied_mutators:
-                    _logger.info('mutate model...')
                     mutator.bind_sampler(self.tpe_sampler)
                     model = mutator.apply(model)
                 # run models
......
import collections
from typing import Dict, Any, List

from ..graph import Model
from ..mutator import Mutator, Sampler


class _FixedSampler(Sampler):
    def __init__(self, sample):
        self.sample = sample

    def choice(self, candidates, mutator, model, index):
        return self.sample[(mutator, index)]


def dry_run_for_search_space(model: Model, mutators: List[Mutator]) -> Dict[Any, List[Any]]:
    search_space = collections.OrderedDict()
    for mutator in mutators:
        recorded_candidates, model = mutator.dry_run(model)
        for i, candidates in enumerate(recorded_candidates):
            search_space[(mutator, i)] = candidates
    return search_space
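# The returned space maps (mutator, index) keys to candidate lists. Keeping the
# mutator object in the key is what allows get_targeted_model below to replay a
# chosen sample deterministically through _FixedSampler.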
def get_targeted_model(base_model: Model, mutators: List[Mutator], sample: dict) -> Model:
    sampler = _FixedSampler(sample)
    model = base_model
    for mutator in mutators:
        model = mutator.bind_sampler(sampler).apply(model)
    return model
@@ -5,9 +5,9 @@ import torch
 from pathlib import Path

 import nni.retiarii.trainer.pytorch.lightning as pl
+import nni.retiarii.strategy as strategy
 from nni.retiarii import blackbox_module as bm
 from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
-from nni.retiarii.strategies import TPEStrategy, RandomStrategy
 from torchvision import transforms
 from torchvision.datasets import CIFAR10
@@ -33,9 +33,9 @@ if __name__ == '__main__':
                       val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
                       max_epochs=1, limit_train_batches=0.2)

-    simple_startegy = RandomStrategy()
+    simple_strategy = strategy.Random()

-    exp = RetiariiExperiment(base_model, trainer, [], simple_startegy)
+    exp = RetiariiExperiment(base_model, trainer, [], simple_strategy)
     exp_config = RetiariiExeConfig('local')
     exp_config.experiment_name = 'darts_search'
......
@@ -8,8 +8,7 @@ from pathlib import Path
 from torchvision import transforms
 from torchvision.datasets import CIFAR10

-from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
-from nni.retiarii.strategies import TPEStrategy
+from nni.retiarii.experiment.pytorch import RetiariiExperiment
 from nni.retiarii.trainer.pytorch import DartsTrainer

 from darts_model import CNN
......
@@ -9,7 +9,7 @@ import nni.retiarii.trainer.pytorch.lightning as pl
 from nni.retiarii import blackbox_module as bm
 from base_mnasnet import MNASNet
 from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
-from nni.retiarii.strategies import TPEStrategy
+from nni.retiarii.strategy import TPEStrategy
 from torchvision import transforms
 from torchvision.datasets import CIFAR10
@@ -46,9 +46,9 @@ if __name__ == '__main__':
         BlockMutator('mutable_1')
     ]

-    simple_startegy = TPEStrategy()
+    simple_strategy = TPEStrategy()

-    exp = RetiariiExperiment(base_model, trainer, applied_mutators, simple_startegy)
+    exp = RetiariiExperiment(base_model, trainer, applied_mutators, simple_strategy)
     exp_config = RetiariiExeConfig('local')
     exp_config.experiment_name = 'mnasnet_search'
......
 import random

 import nni.retiarii.nn.pytorch as nn
+import nni.retiarii.strategy as strategy
 import nni.retiarii.trainer.pytorch.lightning as pl
 import torch.nn.functional as F
 from nni.retiarii import blackbox_module as bm
 from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment
-from nni.retiarii.strategies import RandomStrategy
 from torch.utils.data import DataLoader
 from torchvision import transforms
 from torchvision.datasets import MNIST
@@ -42,9 +42,9 @@ if __name__ == '__main__':
                       val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
                       max_epochs=2)

-    simple_startegy = RandomStrategy()
+    simple_strategy = strategy.Random()

-    exp = RetiariiExperiment(base_model, trainer, [], simple_startegy)
+    exp = RetiariiExperiment(base_model, trainer, [], simple_strategy)
     exp_config = RetiariiExeConfig('local')
     exp_config.experiment_name = 'mnist_search'
......
import random
import time
import threading
from typing import *

import nni.retiarii.execution.api
import nni.retiarii.nn.pytorch as nn
import nni.retiarii.strategy as strategy
import torch
import torch.nn.functional as F
from nni.retiarii import Model
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.execution import wait_models
from nni.retiarii.execution.interface import AbstractExecutionEngine, WorkerInfo, MetricData, AbstractGraphListener
from nni.retiarii.graph import DebugTraining, ModelStatus
from nni.retiarii.nn.pytorch.mutator import process_inline_mutation


class MockExecutionEngine(AbstractExecutionEngine):
    def __init__(self, failure_prob=0.):
        self.models = []
        self.failure_prob = failure_prob
        self._resource_left = 4

    def _model_complete(self, model: Model):
        time.sleep(random.uniform(0, 1))
        if random.uniform(0, 1) < self.failure_prob:
            model.status = ModelStatus.Failed
        else:
            model.metric = random.uniform(0, 1)
            model.status = ModelStatus.Trained
        self._resource_left += 1

    def submit_models(self, *models: Model) -> None:
        for model in models:
            self.models.append(model)
            self._resource_left -= 1
            threading.Thread(target=self._model_complete, args=(model, )).start()

    def query_available_resource(self) -> Union[List[WorkerInfo], int]:
        return self._resource_left

    def register_graph_listener(self, listener: AbstractGraphListener) -> None:
        pass

    def trial_execute_graph(cls) -> MetricData:
        pass


def _reset_execution_engine(engine=None):
    nni.retiarii.execution.api._execution_engine = engine


class Net(nn.Module):
    def __init__(self, hidden_size=32):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.LayerChoice([
            nn.Linear(4*4*50, hidden_size, bias=True),
            nn.Linear(4*4*50, hidden_size, bias=False)
        ])
        self.fc2 = nn.LayerChoice([
            nn.Linear(hidden_size, 10, bias=False),
            nn.Linear(hidden_size, 10, bias=True)
        ])

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def _get_model_and_mutators():
    base_model = Net()
    script_module = torch.jit.script(base_model)
    base_model_ir = convert_to_graph(script_module, base_model)
    base_model_ir.training_config = DebugTraining()
    mutators = process_inline_mutation(base_model_ir)
    return base_model_ir, mutators


def test_grid_search():
    gridsearch = strategy.GridSearch()
    engine = MockExecutionEngine()
    _reset_execution_engine(engine)
    gridsearch.run(*_get_model_and_mutators())
    wait_models(*engine.models)
    selection = set()
    for model in engine.models:
        selection.add((
            model.get_node_by_name('_model__fc1').operation.parameters['bias'],
            model.get_node_by_name('_model__fc2').operation.parameters['bias']
        ))
    assert len(selection) == 4
    _reset_execution_engine()


def test_random_search():
    # renamed from `random` to avoid shadowing the random module imported above
    random_strategy = strategy.Random()
    engine = MockExecutionEngine()
    _reset_execution_engine(engine)
    random_strategy.run(*_get_model_and_mutators())
    wait_models(*engine.models)
    selection = set()
    for model in engine.models:
        selection.add((
            model.get_node_by_name('_model__fc1').operation.parameters['bias'],
            model.get_node_by_name('_model__fc2').operation.parameters['bias']
        ))
    assert len(selection) == 4
    _reset_execution_engine()


def test_evolution():
    evolution = strategy.RegularizedEvolution(population_size=5, sample_size=3, cycles=10, mutation_prob=0.5, on_failure='ignore')
    engine = MockExecutionEngine(failure_prob=0.2)
    _reset_execution_engine(engine)
    evolution.run(*_get_model_and_mutators())
    wait_models(*engine.models)
    _reset_execution_engine()

    evolution = strategy.RegularizedEvolution(population_size=5, sample_size=3, cycles=10, mutation_prob=0.5, on_failure='worst')
    engine = MockExecutionEngine(failure_prob=0.4)
    _reset_execution_engine(engine)
    evolution.run(*_get_model_and_mutators())
    wait_models(*engine.models)
    _reset_execution_engine()


if __name__ == '__main__':
    test_grid_search()
    test_random_search()
    test_evolution()