Unverified commit a0fd0036, authored by Yuge Zhang, committed by GitHub

Merge pull request #5036 from microsoft/promote-retiarii-to-nas

[DO NOT SQUASH] Promote retiarii to NAS
Parents: d6dcb483 bc6d8796
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import inspect
import os
import warnings
from typing import Any, TypeVar, Type
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.common.serializer import is_traceable, is_wrapped_with_trace, trace, _copy_class_wrapper_attributes
from .utils import ModelNamespace
__all__ = ['get_init_parameters_or_fail', 'serialize', 'serialize_cls', 'basic_unit', 'model_wrapper',
'is_basic_unit', 'is_model_wrapped']
T = TypeVar('T')
def get_init_parameters_or_fail(obj: Any):
if is_traceable(obj):
return obj.trace_kwargs
raise ValueError(f'Object {obj} needs to be serializable but `trace_kwargs` is not available. '
'If it is a built-in module (like Conv2d), please import it from retiarii.nn. '
'If it is a customized module, please decorate it with @basic_unit. '
'For other complex objects (e.g., trainer, optimizer, dataset, dataloader), '
'try to use @nni.trace.')
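# For example (illustrative; assumes the class has been wrapped with ``nni.trace``,
# which records init arguments and exposes them as ``trace_kwargs``):
#
#     conv = nni.trace(nn.Conv2d)(3, 16, kernel_size=1)
#     get_init_parameters_or_fail(conv)  # -> {'in_channels': 3, 'out_channels': 16, 'kernel_size': 1}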
def serialize(cls, *args, **kwargs):
"""
Create a serializable instance inline, without a decorator. For example,
.. code-block:: python
self.op = serialize(MyCustomOp, hidden_units=128)
"""
warnings.warn('nni.retiarii.serialize is deprecated and will be removed in a future release. ' +
'Try to use nni.trace, e.g., nni.trace(torch.optim.Adam)(lr=1e-4), instead.',
category=DeprecationWarning)
return trace(cls)(*args, **kwargs)
def serialize_cls(cls):
"""
Create a serializable class.
"""
warnings.warn('nni.retiarii.serialize_cls is deprecated and will be removed in a future release. ' +
'Try to use nni.trace instead.', category=DeprecationWarning)
return trace(cls)
def basic_unit(cls: T, basic_unit_tag: bool = True) -> T:
"""
Wrapping a module as a basic unit makes it a primitive and stops the engine from digging deeper into it.
``basic_unit_tag`` is true by default. If set to false, the module will not be explicitly marked as a basic unit, and
the graph parser will continue to parse it. Currently, this is only used to handle a special case in ``nn.Sequential``.
Although ``basic_unit`` calls ``trace`` in its implementation, it is not for serialization. Rather, it is meant
to capture the initialization arguments for mutation. Also, the graph execution engine will stop digging into the inner
modules when it reaches a module decorated with ``basic_unit``.
.. code-block:: python
@basic_unit
class PrimitiveOp(nn.Module):
...
"""
# Internal flag. See nni.trace
nni_trace_flag = os.environ.get('NNI_TRACE_FLAG', '')
if nni_trace_flag.lower() == 'disable':
return cls
if _check_wrapped(cls, 'basic_unit'):
return cls
import torch.nn as nn
assert issubclass(cls, nn.Module), 'When using @basic_unit, the class must be a subclass of nn.Module.' # type: ignore
cls = trace(cls)
cls._nni_basic_unit = basic_unit_tag # type: ignore
_torchscript_patch(cls)
return cls
def model_wrapper(cls: T) -> T:
"""
Wrap the base model (search space). For example,
.. code-block:: python
@model_wrapper
class MyModel(nn.Module):
...
The wrapper serves two purposes:
1. Capture the init parameters of the Python class so that it can be re-instantiated in another process.
2. Reset the uid in the namespace so that auto label counting in each model stably starts from zero.
Currently, NNI might not complain in simple cases where ``@model_wrapper`` is actually not needed.
But in the future, we might enforce ``@model_wrapper`` to be required for the base model.
"""
# Internal flag. See nni.trace
nni_trace_flag = os.environ.get('NNI_TRACE_FLAG', '')
if nni_trace_flag.lower() == 'disable':
return cls
if _check_wrapped(cls, 'model_wrapper'):
return cls
import torch.nn as nn
assert issubclass(cls, nn.Module), 'When using @model_wrapper, the class must be a subclass of nn.Module.' # type: ignore
# subclass can still use trace info
wrapper = trace(cls, inheritable=True)
class reset_wrapper(wrapper):
def __init__(self, *args, **kwargs):
self._model_namespace = ModelNamespace()
with self._model_namespace:
super().__init__(*args, **kwargs)
_copy_class_wrapper_attributes(wrapper, reset_wrapper)
reset_wrapper.__wrapped__ = getattr(wrapper, '__wrapped__', wrapper)
reset_wrapper._nni_model_wrapper = True
reset_wrapper._traced = True
_torchscript_patch(cls)
return reset_wrapper
def is_basic_unit(cls_or_instance) -> bool:
if not inspect.isclass(cls_or_instance):
cls_or_instance = cls_or_instance.__class__
return getattr(cls_or_instance, '_nni_basic_unit', False)
def is_model_wrapped(cls_or_instance) -> bool:
if not inspect.isclass(cls_or_instance):
cls_or_instance = cls_or_instance.__class__
return getattr(cls_or_instance, '_nni_model_wrapper', False)
def _check_wrapped(cls: Type, rewrap: str) -> bool:
wrapped = None
if is_model_wrapped(cls):
wrapped = 'model_wrapper'
elif is_basic_unit(cls):
wrapped = 'basic_unit'
elif is_wrapped_with_trace(cls):
wrapped = 'nni.trace'
if wrapped:
if wrapped != rewrap:
raise TypeError(f'{cls} is already wrapped with {wrapped}. Cannot rewrap with {rewrap}.')
return True
return False
def _torchscript_patch(cls) -> None:
# HACK: for torch script
# https://github.com/pytorch/pytorch/pull/45261
# https://github.com/pytorch/pytorch/issues/54688
# I'm not sure whether there will be potential issues
import torch
if hasattr(cls, '_get_nni_attr'): # might not exist on non-linux
cls._get_nni_attr = torch.jit.ignore(cls._get_nni_attr)
if hasattr(cls, 'trace_symbol'):
# these must either all exist or all be missing
try:
cls.trace_symbol = torch.jit.unused(cls.trace_symbol)
cls.trace_args = torch.jit.unused(cls.trace_args)
cls.trace_kwargs = torch.jit.unused(cls.trace_kwargs)
cls.trace_copy = torch.jit.ignore(cls.trace_copy)
except AttributeError as e:
if 'property' in str(e):
raise RuntimeError('Trace on PyTorch module failed. Your PyTorch version might be outdated. '
'Please try to upgrade PyTorch.')
raise
from nni.nas.utils.serializer import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import abc
from typing import List, Any
# pylint: disable=wildcard-import,unused-wildcard-import
from ..graph import Model
from ..mutator import Mutator
class BaseStrategy(abc.ABC):
@abc.abstractmethod
def run(self, base_model: Model, applied_mutators: List[Mutator]) -> None:
pass
def export_top_models(self, top_k: int) -> List[Any]:
raise NotImplementedError('"export_top_models" is not implemented.')
from nni.nas.strategy.base import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
import itertools
import logging
import random
import time
from typing import Any, Dict, List, Sequence, Optional
# pylint: disable=wildcard-import,unused-wildcard-import
from .. import InvalidMutation, Sampler, submit_models, query_available_resources, budget_exhausted
from .base import BaseStrategy
from .utils import dry_run_for_search_space, get_targeted_model, filter_model
_logger = logging.getLogger(__name__)
def grid_generator(search_space: Dict[Any, List[Any]], shuffle=True):
keys = list(search_space.keys())
search_space_values = copy.deepcopy(list(search_space.values()))
if shuffle:
for values in search_space_values:
random.shuffle(values)
for values in itertools.product(*search_space_values):
yield {key: value for key, value in zip(keys, values)}
def random_generator(search_space: Dict[Any, List[Any]], dedup=True, retries=500):
keys = list(search_space.keys())
history = set()
search_space_values = copy.deepcopy(list(search_space.values()))
while True:
selected: Optional[Sequence[Any]] = None
for retry_count in range(retries):
selected = [random.choice(v) for v in search_space_values]
if not dedup:
break
selected = tuple(selected)
if selected not in history:
history.add(selected)
break
if retry_count + 1 == retries:
_logger.debug('Random generation has run out of patience. There is nothing to search. Exiting.')
return
assert selected is not None, 'Retry attempts exhausted.'
yield {key: value for key, value in zip(keys, selected)}
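# Illustrative usage of the two generators above (the search space here is hypothetical;
# real keys are usually (mutator, index) tuples produced by dry_run_for_search_space):
#
#     space = {'kernel': [3, 5, 7], 'depth': [2, 4]}
#     list(grid_generator(space, shuffle=False))
#     # -> [{'kernel': 3, 'depth': 2}, {'kernel': 3, 'depth': 4}, {'kernel': 5, 'depth': 2}, ...]
#     next(random_generator(space))
#     # -> one random, non-duplicate sample, e.g., {'kernel': 5, 'depth': 4}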
class GridSearch(BaseStrategy):
"""
Traverse the search space and try all the possible combinations one by one.
Parameters
----------
shuffle : bool
Shuffle the order in a candidate list, so that they are tried in a random order. Default: true.
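Examples
--------
A minimal sketch (the surrounding ``RetiariiExperiment`` setup is assumed, not shown here):

.. code-block:: python

    search_strategy = GridSearch(shuffle=True)
    experiment = RetiariiExperiment(base_model, evaluator, strategy=search_strategy)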
"""
def __init__(self, shuffle=True):
self._polling_interval = 2.
self.shuffle = shuffle
def run(self, base_model, applied_mutators):
search_space = dry_run_for_search_space(base_model, applied_mutators)
for sample in grid_generator(search_space, shuffle=self.shuffle):
_logger.debug('New model created. Waiting for resource. %s', str(sample))
while query_available_resources() <= 0:
if budget_exhausted():
return
time.sleep(self._polling_interval)
submit_models(get_targeted_model(base_model, applied_mutators, sample))
class _RandomSampler(Sampler):
def choice(self, candidates, mutator, model, index):
return random.choice(candidates)
class Random(BaseStrategy):
"""
Random search on the search space.
Parameters
----------
variational : bool
Do not dry run to get the full search space. Used when the search space is variational in size or candidates. Default: false.
dedup : bool
Do not try the same configuration twice. When variational is true, deduplication is not supported. Default: true.
model_filter : Callable[[Model], bool]
Takes a model and returns a bool. Used to filter the models in the search space, deciding which to submit.
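Examples
--------
A minimal sketch; ``is_small_enough`` is a hypothetical user-defined predicate on the graph IR model:

.. code-block:: python

    search_strategy = Random(dedup=True, model_filter=lambda model: is_small_enough(model))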
"""
def __init__(self, variational=False, dedup=True, model_filter=None):
self.variational = variational
self.dedup = dedup
if variational and dedup:
raise ValueError('Dedup is not supported in variational mode.')
self.random_sampler = _RandomSampler()
self._polling_interval = 2.
self.filter = model_filter
def run(self, base_model, applied_mutators):
if self.variational:
_logger.info('Random search running in variational mode.')
sampler = _RandomSampler()
for mutator in applied_mutators:
mutator.bind_sampler(sampler)
while True:
avail_resource = query_available_resources()
if avail_resource > 0:
model = base_model
for mutator in applied_mutators:
model = mutator.apply(model)
_logger.debug('New model created. Applied mutators are: %s', str(applied_mutators))
if filter_model(self.filter, model):
submit_models(model)
elif budget_exhausted():
break
else:
time.sleep(self._polling_interval)
else:
_logger.info('Random search running in fixed size mode. Dedup: %s.', 'on' if self.dedup else 'off')
search_space = dry_run_for_search_space(base_model, applied_mutators)
for sample in random_generator(search_space, dedup=self.dedup):
_logger.debug('New model created. Waiting for resource. %s', str(sample))
while query_available_resources() <= 0:
if budget_exhausted():
return
time.sleep(self._polling_interval)
_logger.debug('Still waiting for resource.')
try:
model = get_targeted_model(base_model, applied_mutators, sample)
if filter_model(self.filter, model):
_logger.debug('Submitting model: %s', model)
submit_models(model)
except InvalidMutation as e:
_logger.warning(f'Invalid mutation: {e}. Skip.')
from nni.nas.strategy.bruteforce import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import collections
import dataclasses
import logging
import random
import time
# pylint: disable=wildcard-import,unused-wildcard-import
from ..execution import query_available_resources, submit_models
from ..graph import ModelStatus
from .base import BaseStrategy
from .utils import dry_run_for_search_space, get_targeted_model, filter_model
_logger = logging.getLogger(__name__)
@dataclasses.dataclass
class Individual:
"""
A class that represents an individual.
Holds two attributes, where ``x`` is the model and ``y`` is the metric (e.g., accuracy).
"""
x: dict
y: float
class RegularizedEvolution(BaseStrategy):
"""
Algorithm for regularized evolution (i.e. aging evolution).
Follows "Algorithm 1" in Real et al. "Regularized Evolution for Image Classifier Architecture Search".
Parameters
----------
optimize_mode : str
Can be one of "maximize" and "minimize". Default: maximize.
population_size : int
The number of individuals to keep in the population. Default: 100.
cycles : int
The number of cycles (trials) the algorithm should run for. Default: 20000.
sample_size : int
The number of individuals that should participate in each tournament. Default: 25.
mutation_prob : float
Probability that mutation happens in each dim. Default: 0.05.
on_failure : str
Can be one of "ignore" and "worst". If "ignore", simply give up the model and find a new one.
If "worst", mark the model as -inf (if maximize, inf if minimize), so that the algorithm "learns" to avoid such models.
Default: ignore.
model_filter : Callable[[Model], bool]
Takes a model and returns a bool. Used to filter the models in the search space, deciding which to submit.
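Examples
--------
A minimal sketch with a smaller-than-default population, for illustration only
(``sample_size`` must stay below ``population_size``):

.. code-block:: python

    search_strategy = RegularizedEvolution(
        optimize_mode='maximize', population_size=50, sample_size=10, cycles=1000)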
"""
def __init__(self, optimize_mode='maximize', population_size=100, sample_size=25, cycles=20000,
mutation_prob=0.05, on_failure='ignore', model_filter=None):
assert optimize_mode in ['maximize', 'minimize']
assert on_failure in ['ignore', 'worst']
assert sample_size < population_size
self.optimize_mode = optimize_mode
self.population_size = population_size
self.sample_size = sample_size
self.cycles = cycles
self.mutation_prob = mutation_prob
self.on_failure = on_failure
self._worst = float('-inf') if self.optimize_mode == 'maximize' else float('inf')
self._success_count = 0
self._population = collections.deque()
self._running_models = []
self._polling_interval = 2.
self.filter = model_filter
def random(self, search_space):
return {k: random.choice(v) for k, v in search_space.items()}
def mutate(self, parent, search_space):
child = {}
for k, v in parent.items():
if random.uniform(0, 1) < self.mutation_prob:
# NOTE: we do not exclude the original choice here for simplicity,
# which is slightly different from the original paper.
child[k] = random.choice(search_space[k])
else:
child[k] = v
return child
def best_parent(self):
samples = list(self._population) # copy population so the original order is untouched
random.shuffle(samples)
samples = samples[:self.sample_size]
if self.optimize_mode == 'maximize':
parent = max(samples, key=lambda sample: sample.y)
else:
parent = min(samples, key=lambda sample: sample.y)
return parent.x
def run(self, base_model, applied_mutators):
search_space = dry_run_for_search_space(base_model, applied_mutators)
# Run the first population regardless of concurrency
_logger.info('Initializing the first population.')
while len(self._population) + len(self._running_models) <= self.population_size:
# try to submit new models
while len(self._population) + len(self._running_models) < self.population_size:
config = self.random(search_space)
self._submit_config(config, base_model, applied_mutators)
# collect results
self._move_succeeded_models_to_population()
self._remove_failed_models_from_running_list()
time.sleep(self._polling_interval)
if len(self._population) >= self.population_size:
break
# Resource-aware mutation of models
_logger.info('Running mutations.')
while self._success_count + len(self._running_models) <= self.cycles:
# try to submit new models
while query_available_resources() > 0 and self._success_count + len(self._running_models) < self.cycles:
config = self.mutate(self.best_parent(), search_space)
self._submit_config(config, base_model, applied_mutators)
# collect results
self._move_succeeded_models_to_population()
self._remove_failed_models_from_running_list()
time.sleep(self._polling_interval)
if self._success_count >= self.cycles:
break
def _submit_config(self, config, base_model, mutators):
_logger.debug('Model submitted to running queue: %s', config)
model = get_targeted_model(base_model, mutators, config)
if not filter_model(self.filter, model):
if self.on_failure == "worst":
model.status = ModelStatus.Failed
self._running_models.append((config, model))
else:
submit_models(model)
self._running_models.append((config, model))
return model
def _move_succeeded_models_to_population(self):
completed_indices = []
for i, (config, model) in enumerate(self._running_models):
metric = None
if self.on_failure == 'worst' and model.status == ModelStatus.Failed:
metric = self._worst
elif model.status == ModelStatus.Trained:
metric = model.metric
if metric is not None:
individual = Individual(config, metric)
_logger.debug('Individual created: %s', str(individual))
self._population.append(individual)
if len(self._population) > self.population_size:
self._population.popleft()
completed_indices.append(i)
for i in completed_indices[::-1]:
# delete from end to start so that the indices of the remaining items are not affected.
self._success_count += 1
self._running_models.pop(i)
def _remove_failed_models_from_running_list(self):
# This is only done when on_failure policy is set to "ignore".
# Otherwise, failed models will be treated as inf when processed.
if self.on_failure == 'ignore':
number_of_failed_models = len([g for g in self._running_models if g[1].status == ModelStatus.Failed])
self._running_models = [g for g in self._running_models if g[1].status != ModelStatus.Failed]
if number_of_failed_models > 0:
_logger.info('%d failed models are ignored. Will retry.', number_of_failed_models)
from nni.nas.strategy.evolution import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import os
import random
import string
# pylint: disable=wildcard-import,unused-wildcard-import
from .. import Sampler, codegen, utils
from ..execution.base import BaseGraphData
from ..execution.utils import get_mutation_summary
from .base import BaseStrategy
_logger = logging.getLogger(__name__)
class ChooseFirstSampler(Sampler):
def choice(self, candidates, mutator, model, index):
return candidates[0]
class _LocalDebugStrategy(BaseStrategy):
"""
This class is supposed to be used internally, for debugging trial mutation.
"""
def run_one_model(self, model):
mutation_summary = get_mutation_summary(model)
graph_data = BaseGraphData(codegen.pytorch.model_to_pytorch_script(model), model.evaluator, mutation_summary) # type: ignore
random_str = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(6))
file_name = f'_generated_model/{random_str}.py'
os.makedirs(os.path.dirname(file_name), exist_ok=True)
with open(file_name, 'w') as f:
f.write(graph_data.model_script)
model_cls = utils.import_(f'_generated_model.{random_str}._model')
graph_data.evaluator._execute(model_cls)
os.remove(file_name)
def run(self, base_model, applied_mutators):
_logger.info('local debug strategy has been started.')
model = base_model
_logger.debug('New model created. Applied mutators: %s', str(applied_mutators))
choose_first_sampler = ChooseFirstSampler()
for mutator in applied_mutators:
mutator.bind_sampler(choose_first_sampler)
model = mutator.apply(model)
# directly run models
self.run_one_model(model)
from nni.nas.strategy.debug import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .base import BaseStrategy
# pylint: disable=wildcard-import,unused-wildcard-import
try:
from nni.retiarii.oneshot.pytorch.strategy import ( # pylint: disable=unused-import
DARTS, GumbelDARTS, Proxyless, ENAS, RandomOneShot
)
except ImportError as import_err:
_import_err = import_err
class ImportFailedStrategy(BaseStrategy):
def run(self, base_model, applied_mutators):
raise _import_err
# otherwise the typing check will point to the wrong location
globals()['DARTS'] = ImportFailedStrategy
globals()['GumbelDARTS'] = ImportFailedStrategy
globals()['Proxyless'] = ImportFailedStrategy
globals()['ENAS'] = ImportFailedStrategy
globals()['RandomOneShot'] = ImportFailedStrategy
from nni.nas.strategy.oneshot import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from typing import Optional, Callable
# pylint: disable=wildcard-import,unused-wildcard-import
from .base import BaseStrategy
from .utils import dry_run_for_search_space
from ..execution import query_available_resources
try:
has_tianshou = True
import torch
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import BaseVectorEnv
from tianshou.policy import BasePolicy, PPOPolicy # pylint: disable=unused-import
from ._rl_impl import ModelEvaluationEnv, MultiThreadEnvWorker, Preprocessor, Actor, Critic
except ImportError:
has_tianshou = False
_logger = logging.getLogger(__name__)
class PolicyBasedRL(BaseStrategy):
"""
Algorithm for policy-based reinforcement learning.
This is a wrapper of algorithms provided in tianshou (PPO by default),
and can be easily customized with other algorithms that inherit ``BasePolicy``
(e.g., `REINFORCE <https://link.springer.com/content/pdf/10.1007/BF00992696.pdf>`__
as in `this paper <https://arxiv.org/abs/1611.01578>`__).
Parameters
----------
max_collect : int
How many times the collector runs to collect trials for RL. Default: 100.
trial_per_collect : int
How many trials (trajectories) the collector collects each time.
After each collect, the trainer samples a batch from the replay buffer and performs the update. Default: 20.
policy_fn : function
Takes a :class:`ModelEvaluationEnv` as input and returns a policy.
See :meth:`PolicyBasedRL._default_policy_fn` for an example.
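Examples
--------
A minimal sketch; a custom ``policy_fn`` would take the environment and return any tianshou
``BasePolicy`` (see ``_default_policy_fn`` below for a working reference):

.. code-block:: python

    search_strategy = PolicyBasedRL(max_collect=50, trial_per_collect=10)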
"""
def __init__(self, max_collect: int = 100, trial_per_collect: int = 20,
policy_fn: Optional[Callable[['ModelEvaluationEnv'], 'BasePolicy']] = None):
if not has_tianshou:
raise ImportError('`tianshou` is required to run RL-based strategy. '
'Please use "pip install tianshou" to install it beforehand.')
self.policy_fn = policy_fn or self._default_policy_fn
self.max_collect = max_collect
self.trial_per_collect = trial_per_collect
@staticmethod
def _default_policy_fn(env):
net = Preprocessor(env.observation_space)
actor = Actor(env.action_space, net)
critic = Critic(net)
optim = torch.optim.Adam(set(actor.parameters()).union(critic.parameters()), lr=1e-4)
return PPOPolicy(actor, critic, optim, torch.distributions.Categorical,
discount_factor=1., action_space=env.action_space)
def run(self, base_model, applied_mutators):
search_space = dry_run_for_search_space(base_model, applied_mutators)
concurrency = query_available_resources()
env_fn = lambda: ModelEvaluationEnv(base_model, applied_mutators, search_space)
policy = self.policy_fn(env_fn())
env = BaseVectorEnv([env_fn for _ in range(concurrency)], MultiThreadEnvWorker)
collector = Collector(policy, env, VectorReplayBuffer(20000, len(env)))
for cur_collect in range(1, self.max_collect + 1):
_logger.info('Collect [%d] Running...', cur_collect)
result = collector.collect(n_episode=self.trial_per_collect)
_logger.info('Collect [%d] Result: %s', cur_collect, str(result))
policy.update(0, collector.buffer, batch_size=64, repeat=5)
from nni.nas.strategy.rl import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import time
from typing import Optional
# pylint: disable=wildcard-import,unused-wildcard-import
from .. import Sampler, submit_models, query_available_resources, is_stopped_exec, budget_exhausted
from .base import BaseStrategy
_logger = logging.getLogger(__name__)
class TPESampler(Sampler):
def __init__(self, optimize_mode='minimize'):
# Move import here to eliminate some warning messages about dill.
from nni.algorithms.hpo.hyperopt_tuner import HyperoptTuner
self.tpe_tuner = HyperoptTuner('tpe', optimize_mode)
self.cur_sample: Optional[dict] = None
self.index: Optional[int] = None
self.total_parameters = {}
def update_sample_space(self, sample_space):
search_space = {}
for i, each in enumerate(sample_space):
search_space[str(i)] = {'_type': 'choice', '_value': each}
self.tpe_tuner.update_search_space(search_space)
def generate_samples(self, model_id):
self.cur_sample = self.tpe_tuner.generate_parameters(model_id)
self.total_parameters[model_id] = self.cur_sample
self.index = 0
def receive_result(self, model_id, result):
self.tpe_tuner.receive_trial_result(model_id, self.total_parameters[model_id], result)
def choice(self, candidates, mutator, model, index):
assert isinstance(self.index, int) and isinstance(self.cur_sample, dict)
chosen = self.cur_sample[str(self.index)]
self.index += 1
return chosen
class TPE(BaseStrategy):
"""
The Tree-structured Parzen Estimator (TPE) is a sequential model-based optimization (SMBO) approach.
Find the details in
`Algorithms for Hyper-Parameter Optimization <https://papers.nips.cc/paper/2011/file/86e8f7ab32cfd12577bc2619bc635690-Paper.pdf>`__.
SMBO methods sequentially construct models to approximate the performance of hyperparameters based on historical measurements,
and then subsequently choose new hyperparameters to test based on this model.
"""
def __init__(self):
self.tpe_sampler = TPESampler()
self.model_id = 0
self.running_models = {}
def run(self, base_model, applied_mutators):
sample_space = []
new_model = base_model
for mutator in applied_mutators:
recorded_candidates, new_model = mutator.dry_run(new_model)
sample_space.extend(recorded_candidates)
self.tpe_sampler.update_sample_space(sample_space)
_logger.info('TPE strategy has been started.')
while not budget_exhausted():
avail_resource = query_available_resources()
if avail_resource > 0:
model = base_model
_logger.debug('New model created. Applied mutators: %s', str(applied_mutators))
self.tpe_sampler.generate_samples(self.model_id)
for mutator in applied_mutators:
mutator.bind_sampler(self.tpe_sampler)
model = mutator.apply(model)
# run models
submit_models(model)
self.running_models[self.model_id] = model
self.model_id += 1
else:
time.sleep(2)
_logger.debug('num of running models: %d', len(self.running_models))
to_be_deleted = []
for _id, _model in self.running_models.items():
if is_stopped_exec(_model):
if _model.metric is not None:
self.tpe_sampler.receive_result(_id, _model.metric)
_logger.debug('tpe receive results: %d, %s', _id, _model.metric)
to_be_deleted.append(_id)
for _id in to_be_deleted:
del self.running_models[_id]
# alias for backward compatibility
TPEStrategy = TPE
from nni.nas.strategy.hpo import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import collections
import logging
from typing import Dict, Any, List
from ..graph import Model
from ..mutator import Mutator, Sampler
# pylint: disable=wildcard-import,unused-wildcard-import
_logger = logging.getLogger(__name__)
class _FixedSampler(Sampler):
def __init__(self, sample):
self.sample = sample
def choice(self, candidates, mutator, model, index):
return self.sample[(mutator, index)]
def dry_run_for_search_space(model: Model, mutators: List[Mutator]) -> Dict[Any, List[Any]]:
search_space = collections.OrderedDict()
for mutator in mutators:
recorded_candidates, model = mutator.dry_run(model)
for i, candidates in enumerate(recorded_candidates):
search_space[(mutator, i)] = candidates
return search_space
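# The returned ordered dict maps (mutator, index) pairs to candidate lists, e.g.,
# {(mutator_a, 0): [16, 32], (mutator_a, 1): [2, 3], (mutator_b, 0): ['relu', 'gelu']}
# (the mutators and values here are hypothetical).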
def dry_run_for_formatted_search_space(model: Model, mutators: List[Mutator]) -> Dict[Any, Dict[Any, Any]]:
search_space = collections.OrderedDict()
for mutator in mutators:
recorded_candidates, model = mutator.dry_run(model)
if len(recorded_candidates) == 1:
search_space[mutator.label] = {'_type': 'choice', '_value': recorded_candidates[0]}
else:
for i, candidate in enumerate(recorded_candidates):
search_space[f'{mutator.label}_{i}'] = {'_type': 'choice', '_value': candidate}
return search_space
def get_targeted_model(base_model: Model, mutators: List[Mutator], sample: dict) -> Model:
sampler = _FixedSampler(sample)
model = base_model
for mutator in mutators:
model = mutator.bind_sampler(sampler).apply(model)
return model
def filter_model(model_filter, ir_model):
if model_filter is not None:
_logger.debug('Check if model satisfies constraints.')
if model_filter(ir_model):
_logger.debug('Model satisfied. Submit the model.')
return True
else:
_logger.debug('Model unsatisfied. Discard the model.')
return False
else:
return True
from nni.nas.strategy.utils import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Entrypoint for trials.
# pylint: disable=wildcard-import,unused-wildcard-import
Assuming execution engine is BaseExecutionEngine.
"""
import argparse
from nni.nas.execution.trial_entry import main
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('exec', choices=['base', 'py', 'cgo', 'benchmark'])
args = parser.parse_args()
if args.exec == 'base':
from .execution.base import BaseExecutionEngine
engine = BaseExecutionEngine
elif args.exec == 'cgo':
from .execution.cgo_engine import CGOExecutionEngine
engine = CGOExecutionEngine
elif args.exec == 'py':
from .execution.python import PurePythonExecutionEngine
engine = PurePythonExecutionEngine
elif args.exec == 'benchmark':
from .execution.benchmark import BenchmarkExecutionEngine
engine = BenchmarkExecutionEngine
else:
raise ValueError(f'Unrecognized execution engine: {args.exec}')
engine.trial_execute_graph()
main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import inspect
import itertools
import warnings
from collections import defaultdict
from contextlib import contextmanager
from typing import Any, List, Dict, cast
from pathlib import Path
# pylint: disable=wildcard-import,unused-wildcard-import
from nni.common.hpo_utils import ParameterSpec
__all__ = ['NoContextError', 'ContextStack', 'ModelNamespace', 'original_state_dict_hooks']
def import_(target: str, allow_none: bool = False) -> Any:
if target is None:
return None
path, identifier = target.rsplit('.', 1)
module = __import__(path, globals(), locals(), [identifier])
return getattr(module, identifier)
_last_uid = defaultdict(int)
_DEFAULT_MODEL_NAMESPACE = 'model'
def uid(namespace: str = 'default') -> int:
_last_uid[namespace] += 1
return _last_uid[namespace]
def reset_uid(namespace: str = 'default') -> None:
_last_uid[namespace] = 0
def get_module_name(cls_or_func):
module_name = cls_or_func.__module__
if module_name == '__main__':
# infer the module name with inspect
for frm in inspect.stack():
module = inspect.getmodule(frm[0])
if module is not None and module.__name__ == '__main__':
# main module found
main_file_path = Path(cast(str, inspect.getsourcefile(frm[0])))
if not Path().samefile(main_file_path.parent):
raise RuntimeError(f'You are using "{main_file_path}" to launch your experiment, '
f'please launch the experiment under the directory where "{main_file_path.name}" is located.')
module_name = main_file_path.stem
break
if module_name == '__main__':
warnings.warn('Callstack exhausted but main module still not found. This will probably cause issues, as the '
'function/class cannot be imported.')
# NOTE: this is hacky. TorchScript retrieves LSTM's source code during scripting.
# To make sure LSTM's source code can be found, we should assign the original LSTM's __module__ to
# the wrapped LSTM's __module__.
# TODO: find out all the modules that have the same requirement as LSTM
if f'{cls_or_func.__module__}.{cls_or_func.__name__}' == 'torch.nn.modules.rnn.LSTM':
module_name = cls_or_func.__module__
return module_name
def get_importable_name(cls, relocate_module=False):
module_name = get_module_name(cls) if relocate_module else cls.__module__
return module_name + '.' + cls.__name__
class NoContextError(Exception):
"""Exception raised when context is missing."""
pass
class ContextStack:
"""
This is to maintain a globally-accessible context environment that is visible everywhere.
Use ``with ContextStack(namespace, value):`` to initiate, and use ``get_current_context(namespace)`` to
get the corresponding value in the namespace.
Note that this is not multi-processing safe. Also, the values will get cleared for a new process.
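For example:

.. code-block:: python

    with ContextStack('fixed', {'layer1': 0}):
        assert ContextStack.top('fixed') == {'layer1': 0}
    # popped automatically when the with-block exits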
"""
_stack: Dict[str, List[Any]] = defaultdict(list)
def __init__(self, key: str, value: Any):
self.key = key
self.value = value
def __enter__(self):
self.push(self.key, self.value)
return self
def __exit__(self, *args, **kwargs):
self.pop(self.key)
@classmethod
def push(cls, key: str, value: Any):
cls._stack[key].append(value)
@classmethod
def pop(cls, key: str) -> None:
cls._stack[key].pop()
@classmethod
def top(cls, key: str) -> Any:
if not cls._stack[key]:
raise NoContextError('Context is empty.')
return cls._stack[key][-1]
class ModelNamespace:
"""
Create an individual namespace for models:
1. to enable automatic numbering;
2. to trace general information (like creation of hyper-parameters) of the model.
A namespace is bound to a key. Namespaces bound to different keys are completely isolated.
A namespace can have sub-namespaces (with the same key). The numbering will be chained (e.g., ``model_1_4_2``).
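For example (assuming no enclosing namespace, with the default key):

.. code-block:: python

    with ModelNamespace():
        ModelNamespace.next_label()  # -> 'model_1'
        ModelNamespace.next_label()  # -> 'model_2'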
"""
def __init__(self, key: str = _DEFAULT_MODEL_NAMESPACE):
# for example, key: "model_wrapper"
self.key = key
# the "path" of current name
# By default, it's ``[]``
# If a ``@model_wrapper`` is nested inside a model_wrapper, it will become something like ``[1, 3, 2]``.
# See ``__enter__``.
self.name_path: List[int] = []
# parameter specs.
# Currently only used to trace calls of ModelParameterChoice.
self.parameter_specs: List[ParameterSpec] = []
def __enter__(self):
# For example, if the top of the stack is currently [1, 2, 2] and [1, 2, 2, 3] is already used,
# the next one will be [1, 2, 2, 4].
# `reset_uid` to count from zero for "model_wrapper_1_2_2_4"
try:
parent_context: 'ModelNamespace' = ModelNamespace.current_context(self.key)
next_uid = uid(parent_context._simple_name())
self.name_path = parent_context.name_path + [next_uid]
ContextStack.push(self.key, self)
reset_uid(self._simple_name())
except NoContextError:
# not found, no existing namespace
self.name_path = []
ContextStack.push(self.key, self)
reset_uid(self._simple_name())
def __exit__(self, *args, **kwargs):
ContextStack.pop(self.key)
def _simple_name(self) -> str:
return self.key + ''.join(['_' + str(k) for k in self.name_path])
def __repr__(self):
return f'ModelNamespace(name={self._simple_name()}, num_specs={len(self.parameter_specs)})'
# Access the current context in the model #
@staticmethod
def current_context(key: str = _DEFAULT_MODEL_NAMESPACE) -> 'ModelNamespace':
"""Get the current context in key."""
try:
return ContextStack.top(key)
except NoContextError:
raise NoContextError('ModelNamespace context is missing. You might have forgotten to use `@model_wrapper`.')
@staticmethod
def next_label(key: str = _DEFAULT_MODEL_NAMESPACE) -> str:
"""Get the next label for API calls, with automatic numbering."""
try:
current_context = ContextStack.top(key)
except NoContextError:
# fallback to use "default" namespace
# it won't be registered
warnings.warn('ModelNamespace is missing. You might have forgotten to use `@model_wrapper`. '
'Some features might not work. This will be an error in future releases.', RuntimeWarning)
current_context = ModelNamespace('default')
next_uid = uid(current_context._simple_name())
return current_context._simple_name() + '_' + str(next_uid)
def get_current_context(key: str) -> Any:
return ContextStack.top(key)
# map variables to prefix in the state dict
# e.g., {'upsample': 'mynet.module.deconv2.upsample_layer'}
STATE_DICT_PY_MAPPING = '_mapping_'
# map variables to `prefix`.`value` in the state dict
# e.g., {'upsample': 'choice3.upsample_layer'},
# which actually means {'upsample': 'mynet.module.choice3.upsample_layer'},
# and 'upsample' is also in `mynet.module`.
STATE_DICT_PY_MAPPING_PARTIAL = '_mapping_partial_'
@contextmanager
def original_state_dict_hooks(model: Any):
"""
Use this patch if you want to save/load state dict in the original state dict hierarchy.
For example, when you already have a state dict for the base model / search space (which often
happens when you have trained a supernet with one-shot strategies), the state dict isn't organized
in the same way as when a sub-model is sampled from the search space. This patch will help
the modules in the sub-model find the corresponding module in the base model.
The code looks like,
.. code-block:: python
with original_state_dict_hooks(model):
model.load_state_dict(state_dict_from_supernet, strict=False) # supernet has extra keys
Or vice-versa,
.. code-block:: python
with original_state_dict_hooks(model):
supernet_style_state_dict = model.state_dict()
"""
import torch.utils.hooks
import torch.nn as nn
assert isinstance(model, nn.Module), 'PyTorch is the only supported framework for now.'
# the following are written for pytorch only
# first get the full mapping
full_mapping = {}
def full_mapping_in_module(src_prefix, tar_prefix, module):
if hasattr(module, STATE_DICT_PY_MAPPING):
# only values are complete
local_map = getattr(module, STATE_DICT_PY_MAPPING)
elif hasattr(module, STATE_DICT_PY_MAPPING_PARTIAL):
# keys and values are both incomplete
local_map = getattr(module, STATE_DICT_PY_MAPPING_PARTIAL)
local_map = {k: tar_prefix + v for k, v in local_map.items()}
else:
# no mapping
local_map = {}
if '__self__' in local_map:
# special case, overwrite prefix
tar_prefix = local_map['__self__'] + '.'
for key, value in local_map.items():
if key != '' and key not in module._modules: # not a sub-module, probably a parameter
full_mapping[src_prefix + key] = value
if src_prefix != tar_prefix: # To deal with leaf nodes.
for name, value in itertools.chain(module._parameters.items(), module._buffers.items()): # direct children
if value is None or name in module._non_persistent_buffers_set:
# it won't appear in state dict
continue
if (src_prefix + name) not in full_mapping:
full_mapping[src_prefix + name] = tar_prefix + name
for name, child in module.named_children():
# sub-modules
full_mapping_in_module(
src_prefix + name + '.',
local_map.get(name, tar_prefix + name) + '.', # if mapping doesn't exist, respect the prefix
child
)
full_mapping_in_module('', '', model)
def load_state_dict_hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
reverse_mapping = defaultdict(list)
for src, tar in full_mapping.items():
reverse_mapping[tar].append(src)
transf_state_dict = {}
for src, tar_keys in reverse_mapping.items():
if src in state_dict:
value = state_dict.pop(src)
for tar in tar_keys:
transf_state_dict[tar] = value
else:
missing_keys.append(src)
state_dict.update(transf_state_dict)
def state_dict_hook(module, destination, prefix, local_metadata):
result = {}
for src, tar in full_mapping.items():
if src in destination:
result[tar] = destination.pop(src)
else:
raise KeyError(f'"{src}" not in state dict, but found in mapping.')
destination.update(result)
hooks: List[torch.utils.hooks.RemovableHandle] = []
try:
hooks.append(model._register_load_state_dict_pre_hook(load_state_dict_hook))
hooks.append(model._register_state_dict_hook(state_dict_hook))
yield
finally:
for hook in hooks:
hook.remove()
from nni.nas.utils.misc import *
@@ -10,8 +10,8 @@
"nni/common/device.py",
"nni/common/graph_utils.py",
"nni/compression",
"nni/nas/tensorflow",
"nni/nas/pytorch",
"nni/nas/execution/pytorch/cgo",
"nni/nas/evaluator/pytorch/cgo",
"nni/retiarii/execution/cgo_engine.py",
"nni/retiarii/execution/logical_optimizer",
"nni/retiarii/evaluator/pytorch/cgo",
......
@@ -32,6 +32,8 @@ try:
from nni.retiarii.evaluator.pytorch.cgo.evaluator import MultiModelSupervisedLearningModule, _MultiModelSupervisedLearningModule
import nni.retiarii.evaluator.pytorch.cgo.trainer as cgo_trainer
import nni.retiarii.integration_api
+module_import_failed = False
except ImportError:
+module_import_failed = True
......
@@ -14,7 +14,7 @@ import nni.runtime.platform.test
import nni.retiarii.evaluator.pytorch.lightning as pl
import nni.retiarii.hub.pytorch as searchspace
from nni.retiarii import fixed_arch
-from nni.retiarii.execution.utils import _unpack_if_only_one
+from nni.retiarii.execution.utils import unpack_if_only_one
from nni.retiarii.mutator import InvalidMutation, Sampler
from nni.retiarii.nn.pytorch.mutator import extract_mutation_from_pt_module
@@ -58,7 +58,7 @@ def _test_searchspace_on_dataset(searchspace, dataset='cifar10', arch=None):
if arch is None:
model = try_mutation_until_success(model, mutators, 10)
-arch = {mut.mutator.label: _unpack_if_only_one(mut.samples) for mut in model.history}
+arch = {mut.mutator.label: unpack_if_only_one(mut.samples) for mut in model.history}
print('Selected model:', arch)
with fixed_arch(arch):
......
@@ -56,7 +56,10 @@ class MockExecutionEngine(AbstractExecutionEngine):
def _reset_execution_engine(engine=None):
-nni.retiarii.execution.api._execution_engine = engine
+# Use the new NAS reset
+# nni.retiarii.execution.api._execution_engine = engine
+import nni.nas.execution.api
+nni.nas.execution.api._execution_engine = engine
class Net(nn.Module):
......
@@ -3,7 +3,7 @@ import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
-import nni.retiarii.nn.pytorch
+import nni.nas.nn.pytorch
import torch
......
@@ -4,6 +4,7 @@ import unittest
from pathlib import Path
import nni.retiarii
import nni.retiarii.integration_api
from nni.retiarii import Model, submit_models
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.execution import set_execution_engine
......
import json
from pathlib import Path
import sys
+from nni.common.framework import get_default_framework, set_default_framework
from nni.retiarii import *
-# FIXME
-import nni.retiarii.debug_configs
-original_framework = nni.retiarii.debug_configs.framework
+original_framework = get_default_framework()
max_pool = Operation.new('MaxPool2D', {'pool_size': 2})
avg_pool = Operation.new('AveragePooling2D', {'pool_size': 2})
@@ -14,11 +12,11 @@ global_pool = Operation.new('GlobalAveragePooling2D')
def setup_module(module):
-nni.retiarii.debug_configs.framework = 'tensorflow'
+set_default_framework('tensorflow')
def teardown_module(module):
-nni.retiarii.debug_configs.framework = original_framework
+set_default_framework(original_framework)
class DebugSampler(Sampler):
......
@@ -15,7 +15,7 @@ from nni.retiarii import InvalidMutation, Sampler, basic_unit
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.evaluator import FunctionalEvaluator
-from nni.retiarii.execution.utils import _unpack_if_only_one
+from nni.retiarii.execution.utils import unpack_if_only_one
from nni.retiarii.experiment.pytorch import preprocess_model
from nni.retiarii.graph import Model
from nni.retiarii.nn.pytorch.api import ValueChoice
@@ -827,7 +827,7 @@ class Python(GraphIR):
graph_engine = False
def _get_converted_pytorch_model(self, model_ir):
-mutation = {mut.mutator.label: _unpack_if_only_one(mut.samples) for mut in model_ir.history}
+mutation = {mut.mutator.label: unpack_if_only_one(mut.samples) for mut in model_ir.history}
with ContextStack('fixed', mutation):
model = model_ir.python_class(**model_ir.python_init_params)
return model
......