Unverified Commit 08af7771 authored by QuanluZhang's avatar QuanluZhang Committed by GitHub
Browse files

[2.0a2] [retiarii] improvement (#3208)

parent c444e862
...@@ -432,6 +432,7 @@ def _handle_layerchoice(module): ...@@ -432,6 +432,7 @@ def _handle_layerchoice(module):
def _handle_inputchoice(module): def _handle_inputchoice(module):
m_attrs = {} m_attrs = {}
m_attrs['n_candidates'] = module.n_candidates
m_attrs['n_chosen'] = module.n_chosen m_attrs['n_chosen'] = module.n_chosen
m_attrs['reduction'] = module.reduction m_attrs['reduction'] = module.reduction
m_attrs['label'] = module.label m_attrs['label'] = module.label
......
import time import time
import os
from typing import List
from ..graph import Model, ModelStatus from ..graph import Model, ModelStatus
from .base import BaseExecutionEngine from .interface import AbstractExecutionEngine
from .cgo_engine import CGOExecutionEngine
from .interface import AbstractExecutionEngine, WorkerInfo
from .listener import DefaultListener from .listener import DefaultListener
_execution_engine = None _execution_engine = None
_default_listener = None _default_listener = None
__all__ = ['get_execution_engine', 'get_and_register_default_listener', __all__ = ['get_execution_engine', 'get_and_register_default_listener',
'submit_models', 'wait_models', 'query_available_resources'] 'submit_models', 'wait_models', 'query_available_resources',
'set_execution_engine', 'is_stopped_exec']
def set_execution_engine(engine) -> None:
global _execution_engine
if _execution_engine is None:
_execution_engine = engine
else:
raise RuntimeError('execution engine is already set')
def get_execution_engine() -> BaseExecutionEngine: def get_execution_engine() -> AbstractExecutionEngine:
""" """
Currently we assume the default execution engine is BaseExecutionEngine. Currently we assume the default execution engine is BaseExecutionEngine.
""" """
global _execution_engine global _execution_engine
if _execution_engine is None:
if os.environ.get('CGO') == 'true':
_execution_engine = CGOExecutionEngine()
else:
_execution_engine = BaseExecutionEngine()
return _execution_engine return _execution_engine
...@@ -51,6 +49,11 @@ def wait_models(*models: Model) -> None: ...@@ -51,6 +49,11 @@ def wait_models(*models: Model) -> None:
break break
def query_available_resources() -> List[WorkerInfo]: def query_available_resources() -> int:
listener = get_and_register_default_listener(get_execution_engine()) engine = get_execution_engine()
return listener.resources resources = engine.query_available_resource()
return resources if isinstance(resources, int) else len(resources)
def is_stopped_exec(model: Model) -> bool:
return model.status in (ModelStatus.Trained, ModelStatus.Failed)
import logging import logging
import os
import random
import string
from typing import Dict, Any, List from typing import Dict, Any, List
from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo from .interface import AbstractExecutionEngine, AbstractGraphListener
from .. import codegen, utils from .. import codegen, utils
from ..graph import Model, ModelStatus, MetricData from ..graph import Model, ModelStatus, MetricData
from ..integration import send_trial, receive_trial_parameters, get_advisor from ..integration_api import send_trial, receive_trial_parameters, get_advisor
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
...@@ -29,7 +32,7 @@ class BaseGraphData: ...@@ -29,7 +32,7 @@ class BaseGraphData:
class BaseExecutionEngine(AbstractExecutionEngine): class BaseExecutionEngine(AbstractExecutionEngine):
""" """
The execution engine with no optimization at all. The execution engine with no optimization at all.
Resource management is yet to be implemented. Resource management is implemented in this class.
""" """
def __init__(self) -> None: def __init__(self) -> None:
...@@ -50,6 +53,8 @@ class BaseExecutionEngine(AbstractExecutionEngine): ...@@ -50,6 +53,8 @@ class BaseExecutionEngine(AbstractExecutionEngine):
self._running_models: Dict[int, Model] = dict() self._running_models: Dict[int, Model] = dict()
self.resources = 0
def submit_models(self, *models: Model) -> None: def submit_models(self, *models: Model) -> None:
for model in models: for model in models:
data = BaseGraphData(codegen.model_to_pytorch_script(model), data = BaseGraphData(codegen.model_to_pytorch_script(model),
...@@ -60,17 +65,14 @@ class BaseExecutionEngine(AbstractExecutionEngine): ...@@ -60,17 +65,14 @@ class BaseExecutionEngine(AbstractExecutionEngine):
self._listeners.append(listener) self._listeners.append(listener)
def _send_trial_callback(self, paramater: dict) -> None: def _send_trial_callback(self, paramater: dict) -> None:
for listener in self._listeners: if self.resources <= 0:
_logger.warning('resources: %s', listener.resources)
if not listener.has_available_resource():
_logger.warning('There is no available resource, but trial is submitted.') _logger.warning('There is no available resource, but trial is submitted.')
listener.on_resource_used(1) self.resources -= 1
_logger.warning('on_resource_used: %s', listener.resources) _logger.info('on_resource_used: %d', self.resources)
def _request_trial_jobs_callback(self, num_trials: int) -> None: def _request_trial_jobs_callback(self, num_trials: int) -> None:
for listener in self._listeners: self.resources += num_trials
listener.on_resource_available(1 * num_trials) _logger.info('on_resource_available: %d', self.resources)
_logger.warning('on_resource_available: %s', listener.resources)
def _trial_end_callback(self, trial_id: int, success: bool) -> None: def _trial_end_callback(self, trial_id: int, success: bool) -> None:
model = self._running_models[trial_id] model = self._running_models[trial_id]
...@@ -93,8 +95,8 @@ class BaseExecutionEngine(AbstractExecutionEngine): ...@@ -93,8 +95,8 @@ class BaseExecutionEngine(AbstractExecutionEngine):
for listener in self._listeners: for listener in self._listeners:
listener.on_metric(model, metrics) listener.on_metric(model, metrics)
def query_available_resource(self) -> List[WorkerInfo]: def query_available_resource(self) -> int:
raise NotImplementedError # move the method from listener to here? return self.resources
@classmethod @classmethod
def trial_execute_graph(cls) -> None: def trial_execute_graph(cls) -> None:
...@@ -102,9 +104,12 @@ class BaseExecutionEngine(AbstractExecutionEngine): ...@@ -102,9 +104,12 @@ class BaseExecutionEngine(AbstractExecutionEngine):
Initialize the model, hand it over to trainer. Initialize the model, hand it over to trainer.
""" """
graph_data = BaseGraphData.load(receive_trial_parameters()) graph_data = BaseGraphData.load(receive_trial_parameters())
with open('_generated_model.py', 'w') as f: random_str = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(6))
file_name = f'_generated_model_{random_str}.py'
with open(file_name, 'w') as f:
f.write(graph_data.model_script) f.write(graph_data.model_script)
trainer_cls = utils.import_(graph_data.training_module) trainer_cls = utils.import_(graph_data.training_module)
model_cls = utils.import_('_generated_model._model') model_cls = utils.import_(f'_generated_model_{random_str}._model')
trainer_instance = trainer_cls(model=model_cls(), **graph_data.training_kwargs) trainer_instance = trainer_cls(model=model_cls(), **graph_data.training_kwargs)
trainer_instance.fit() trainer_instance.fit()
os.remove(file_name)
\ No newline at end of file
...@@ -4,7 +4,7 @@ from typing import List, Dict, Tuple ...@@ -4,7 +4,7 @@ from typing import List, Dict, Tuple
from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo
from .. import codegen, utils from .. import codegen, utils
from ..graph import Model, ModelStatus, MetricData from ..graph import Model, ModelStatus, MetricData
from ..integration import send_trial, receive_trial_parameters, get_advisor from ..integration_api import send_trial, receive_trial_parameters, get_advisor
from .logical_optimizer.logical_plan import LogicalPlan, PhysicalDevice from .logical_optimizer.logical_plan import LogicalPlan, PhysicalDevice
from .logical_optimizer.opt_dedup_input import DedupInputOptimizer from .logical_optimizer.opt_dedup_input import DedupInputOptimizer
......
from abc import ABC, abstractmethod, abstractclassmethod from abc import ABC, abstractmethod, abstractclassmethod
from typing import Any, NewType, List from typing import Any, NewType, List, Union
from ..graph import Model, MetricData from ..graph import Model, MetricData
...@@ -59,13 +59,6 @@ class AbstractGraphListener(ABC): ...@@ -59,13 +59,6 @@ class AbstractGraphListener(ABC):
""" """
pass pass
@abstractmethod
def on_resource_available(self, resources: List[WorkerInfo]) -> None:
"""
Reports when a worker becomes idle.
"""
pass
class AbstractExecutionEngine(ABC): class AbstractExecutionEngine(ABC):
""" """
...@@ -109,7 +102,7 @@ class AbstractExecutionEngine(ABC): ...@@ -109,7 +102,7 @@ class AbstractExecutionEngine(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def query_available_resource(self) -> List[WorkerInfo]: def query_available_resource(self) -> Union[List[WorkerInfo], int]:
""" """
Returns information of all idle workers. Returns information of all idle workers.
If no details are available, this may returns a list of "empty" objects, reporting the number of idle workers. If no details are available, this may returns a list of "empty" objects, reporting the number of idle workers.
......
...@@ -3,11 +3,6 @@ from .interface import MetricData, AbstractGraphListener ...@@ -3,11 +3,6 @@ from .interface import MetricData, AbstractGraphListener
class DefaultListener(AbstractGraphListener): class DefaultListener(AbstractGraphListener):
def __init__(self):
self.resources: int = 0 # simply resource count
def has_available_resource(self) -> bool:
return self.resources > 0
def on_metric(self, model: Model, metric: MetricData) -> None: def on_metric(self, model: Model, metric: MetricData) -> None:
model.metric = metric model.metric = metric
...@@ -20,9 +15,3 @@ class DefaultListener(AbstractGraphListener): ...@@ -20,9 +15,3 @@ class DefaultListener(AbstractGraphListener):
model.status = ModelStatus.Trained model.status = ModelStatus.Trained
else: else:
model.status = ModelStatus.Failed model.status = ModelStatus.Failed
def on_resource_available(self, resources: int) -> None:
self.resources += resources
def on_resource_used(self, resources: int) -> None:
self.resources -= resources
import atexit
import logging import logging
import time import socket
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
...@@ -7,10 +8,14 @@ from subprocess import Popen ...@@ -7,10 +8,14 @@ from subprocess import Popen
from threading import Thread from threading import Thread
from typing import Any, Optional from typing import Any, Optional
from ..experiment import Experiment, TrainingServiceConfig, launcher, rest import colorama
import psutil
from ..experiment import Experiment, TrainingServiceConfig, launcher
from ..experiment.config.base import ConfigBase, PathLike from ..experiment.config.base import ConfigBase, PathLike
from ..experiment.config import util from ..experiment.config import util
from ..experiment.pipe import Pipe from ..experiment.pipe import Pipe
from .graph import Model from .graph import Model
from .utils import get_records from .utils import get_records
from .integration import RetiariiAdvisor from .integration import RetiariiAdvisor
...@@ -18,9 +23,11 @@ from .converter import convert_to_graph ...@@ -18,9 +23,11 @@ from .converter import convert_to_graph
from .mutator import Mutator, LayerChoiceMutator, InputChoiceMutator from .mutator import Mutator, LayerChoiceMutator, InputChoiceMutator
from .trainer.interface import BaseTrainer from .trainer.interface import BaseTrainer
from .strategies.strategy import BaseStrategy from .strategies.strategy import BaseStrategy
from .trainer.pytorch import DartsTrainer, EnasTrainer, ProxylessTrainer, RandomTrainer, SinglePathTrainer
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
OneShotTrainers = (DartsTrainer, EnasTrainer, ProxylessTrainer, RandomTrainer, SinglePathTrainer)
@dataclass(init=False) @dataclass(init=False)
class RetiariiExeConfig(ConfigBase): class RetiariiExeConfig(ConfigBase):
...@@ -76,7 +83,7 @@ _validation_rules = { ...@@ -76,7 +83,7 @@ _validation_rules = {
class RetiariiExperiment(Experiment): class RetiariiExperiment(Experiment):
def __init__(self, base_model: Model, trainer: BaseTrainer, def __init__(self, base_model: Model, trainer: BaseTrainer,
applied_mutators: Mutator, strategy: BaseStrategy): applied_mutators: Mutator = None, strategy: BaseStrategy = None):
self.config: RetiariiExeConfig = None self.config: RetiariiExeConfig = None
self.port: Optional[int] = None self.port: Optional[int] = None
...@@ -87,6 +94,7 @@ class RetiariiExperiment(Experiment): ...@@ -87,6 +94,7 @@ class RetiariiExperiment(Experiment):
self.recorded_module_args = get_records() self.recorded_module_args = get_records()
self._dispatcher = RetiariiAdvisor() self._dispatcher = RetiariiAdvisor()
self._dispatcher_thread: Optional[Thread] = None
self._proc: Optional[Popen] = None self._proc: Optional[Popen] = None
self._pipe: Optional[Pipe] = None self._pipe: Optional[Pipe] = None
...@@ -103,7 +111,10 @@ class RetiariiExperiment(Experiment): ...@@ -103,7 +111,10 @@ class RetiariiExperiment(Experiment):
mutator = LayerChoiceMutator(node.name, node.operation.parameters['choices']) mutator = LayerChoiceMutator(node.name, node.operation.parameters['choices'])
applied_mutators.append(mutator) applied_mutators.append(mutator)
for node in ic_nodes: for node in ic_nodes:
mutator = InputChoiceMutator(node.name, node.operation.parameters['n_chosen']) mutator = InputChoiceMutator(node.name,
node.operation.parameters['n_candidates'],
node.operation.parameters['n_chosen'],
node.operation.parameters['reduction'])
applied_mutators.append(mutator) applied_mutators.append(mutator)
return applied_mutators return applied_mutators
...@@ -132,7 +143,7 @@ class RetiariiExperiment(Experiment): ...@@ -132,7 +143,7 @@ class RetiariiExperiment(Experiment):
Thread(target=self.strategy.run, args=(base_model, self.applied_mutators)).start() Thread(target=self.strategy.run, args=(base_model, self.applied_mutators)).start()
_logger.info('Strategy started!') _logger.info('Strategy started!')
def start(self, config: RetiariiExeConfig, port: int = 8080, debug: bool = False) -> None: def start(self, port: int = 8080, debug: bool = False) -> None:
""" """
Start the experiment in background. Start the experiment in background.
This method will raise exception on failure. This method will raise exception on failure.
...@@ -144,11 +155,12 @@ class RetiariiExperiment(Experiment): ...@@ -144,11 +155,12 @@ class RetiariiExperiment(Experiment):
debug debug
Whether to start in debug mode. Whether to start in debug mode.
""" """
# FIXME: atexit.register(self.stop)
if debug: if debug:
logging.getLogger('nni').setLevel(logging.DEBUG) logging.getLogger('nni').setLevel(logging.DEBUG)
self._proc, self._pipe = launcher.start_experiment(config, port, debug) self._proc, self._pipe = launcher.start_experiment(self.config, port, debug)
assert self._proc is not None assert self._proc is not None
assert self._pipe is not None assert self._pipe is not None
...@@ -156,42 +168,42 @@ class RetiariiExperiment(Experiment): ...@@ -156,42 +168,42 @@ class RetiariiExperiment(Experiment):
# dispatcher must be created after pipe initialized # dispatcher must be created after pipe initialized
# the logic to launch dispatcher in background should be refactored into dispatcher api # the logic to launch dispatcher in background should be refactored into dispatcher api
Thread(target=self._dispatcher.run).start() self._dispatcher_thread = Thread(target=self._dispatcher.run)
self._dispatcher_thread.start()
self._start_strategy() self._start_strategy()
ips = [self.config.nni_manager_ip]
for interfaces in psutil.net_if_addrs().values():
for interface in interfaces:
if interface.family == socket.AF_INET:
ips.append(interface.address)
ips = [f'http://{ip}:{port}' for ip in ips if ip]
msg = 'Web UI URLs: ' + colorama.Fore.CYAN + ' '.join(ips)
_logger.info(msg)
# TODO: register experiment management metadata # TODO: register experiment management metadata
def stop(self) -> None: def run(self, config: RetiariiExeConfig = None, port: int = 8080, debug: bool = False) -> str:
""" """
Stop background experiment. Run the experiment.
This function will block until experiment finish or error.
""" """
self._proc.kill() if isinstance(self.trainer, OneShotTrainers):
self._pipe.close() self.trainer.fit()
else:
assert config is not None, 'You are using classic search mode, config cannot be None!'
self.config = config
super().run(port, debug)
self.port = None def export_top_models(self, top_n: int):
self._proc = None """
self._pipe = None export several top performing models
"""
raise NotImplementedError
def run(self, config: RetiariiExeConfig, port: int = 8080, debug: bool = False) -> str: def retrain_model(self, model):
""" """
Run the experiment. this function retrains the exported model, and test it to output test accuracy
This function will block until experiment finish or error.
""" """
self.config = config raise NotImplementedError
self.start(config, port, debug)
try:
while True:
time.sleep(10)
status = self.get_status()
# TODO: double check the status
if status in ['ERROR', 'STOPPED', 'NO_MORE_TRIAL']:
return status
finally:
self.stop()
def get_status(self) -> str:
if self.port is None:
raise RuntimeError('Experiment is not running')
resp = rest.get(self.port, '/check-status')
return resp['status']
import logging import logging
import os
from typing import Any, Callable from typing import Any, Callable
import json_tricks import json_tricks
import nni
from nni.runtime.msg_dispatcher_base import MsgDispatcherBase from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
from nni.runtime.protocol import CommandType, send from nni.runtime.protocol import CommandType, send
from nni.utils import MetricType from nni.utils import MetricType
from .graph import MetricData from .graph import MetricData
from .execution.base import BaseExecutionEngine
from .execution.cgo_engine import CGOExecutionEngine
from .execution.api import set_execution_engine
from .integration_api import register_advisor
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
...@@ -55,6 +59,15 @@ class RetiariiAdvisor(MsgDispatcherBase): ...@@ -55,6 +59,15 @@ class RetiariiAdvisor(MsgDispatcherBase):
self.parameters_count = 0 self.parameters_count = 0
engine = self._create_execution_engine()
set_execution_engine(engine)
def _create_execution_engine(self):
if os.environ.get('CGO') == 'true':
return CGOExecutionEngine()
else:
return BaseExecutionEngine()
def handle_initialize(self, data): def handle_initialize(self, data):
"""callback for initializing the advisor """callback for initializing the advisor
Parameters Parameters
...@@ -126,34 +139,3 @@ class RetiariiAdvisor(MsgDispatcherBase): ...@@ -126,34 +139,3 @@ class RetiariiAdvisor(MsgDispatcherBase):
else: else:
return value return value
return value return value
_advisor: RetiariiAdvisor = None
def get_advisor() -> RetiariiAdvisor:
global _advisor
assert _advisor is not None
return _advisor
def register_advisor(advisor: RetiariiAdvisor):
global _advisor
assert _advisor is None
_advisor = advisor
def send_trial(parameters: dict) -> int:
"""
Send a new trial. Executed on tuner end.
Return a ID that is the unique identifier for this trial.
"""
return get_advisor().send_trial(parameters)
def receive_trial_parameters() -> dict:
"""
Received a new trial. Executed on trial end.
"""
params = nni.get_next_parameter()
return params
from typing import NewType, Any
import nni
# NOTE: this is only for passing flake8, we cannot import RetiariiAdvisor
# because it would induce cycled import
RetiariiAdvisor = NewType('RetiariiAdvisor', Any)
_advisor: 'RetiariiAdvisor' = None
def get_advisor() -> 'RetiariiAdvisor':
global _advisor
assert _advisor is not None
return _advisor
def register_advisor(advisor: 'RetiariiAdvisor'):
global _advisor
assert _advisor is None
_advisor = advisor
def send_trial(parameters: dict) -> int:
"""
Send a new trial. Executed on tuner end.
Return a ID that is the unique identifier for this trial.
"""
return get_advisor().send_trial(parameters)
def receive_trial_parameters() -> dict:
"""
Received a new trial. Executed on trial end.
"""
params = nni.get_next_parameter()
return params
...@@ -104,6 +104,7 @@ class _RecorderSampler(Sampler): ...@@ -104,6 +104,7 @@ class _RecorderSampler(Sampler):
self.recorded_candidates.append(candidates) self.recorded_candidates.append(candidates)
return candidates[0] return candidates[0]
# the following is for inline mutation # the following is for inline mutation
...@@ -122,14 +123,16 @@ class LayerChoiceMutator(Mutator): ...@@ -122,14 +123,16 @@ class LayerChoiceMutator(Mutator):
class InputChoiceMutator(Mutator): class InputChoiceMutator(Mutator):
def __init__(self, node_name: str, n_chosen: int): def __init__(self, node_name: str, n_candidates: int, n_chosen: int, reduction: str):
super().__init__() super().__init__()
self.node_name = node_name self.node_name = node_name
self.n_candidates = n_candidates
self.n_chosen = n_chosen self.n_chosen = n_chosen
self.reduction = reduction
def mutate(self, model): def mutate(self, model):
target = model.get_node_by_name(self.node_name) target = model.get_node_by_name(self.node_name)
candidates = [i for i in range(self.n_chosen)] candidates = [i for i in range(self.n_candidates)]
chosen = self.choice(candidates) chosen = [self.choice(candidates) for _ in range(self.n_chosen)]
target.update_operation('__torch__.nni.retiarii.nn.pytorch.nn.ChosenInputs', target.update_operation('__torch__.nni.retiarii.nn.pytorch.nn.ChosenInputs',
{'chosen': chosen}) {'chosen': chosen, 'reduction': self.reduction})
...@@ -5,10 +5,12 @@ from typing import Any, List ...@@ -5,10 +5,12 @@ from typing import Any, List
import torch import torch
import torch.nn as nn import torch.nn as nn
from ...utils import add_record from ...utils import add_record, version_larger_equal
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
# NOTE: support pytorch version >= 1.5.0
__all__ = [ __all__ = [
'LayerChoice', 'InputChoice', 'Placeholder', 'LayerChoice', 'InputChoice', 'Placeholder',
'Module', 'Sequential', 'ModuleList', # TODO: 'ModuleDict', 'ParameterList', 'ParameterDict', 'Module', 'Sequential', 'ModuleList', # TODO: 'ModuleDict', 'ParameterList', 'ParameterDict',
...@@ -29,18 +31,27 @@ __all__ = [ ...@@ -29,18 +31,27 @@ __all__ = [
'ConstantPad3d', 'Bilinear', 'CosineSimilarity', 'Unfold', 'Fold', 'ConstantPad3d', 'Bilinear', 'CosineSimilarity', 'Unfold', 'Fold',
'AdaptiveLogSoftmaxWithLoss', 'TransformerEncoder', 'TransformerDecoder', 'AdaptiveLogSoftmaxWithLoss', 'TransformerEncoder', 'TransformerDecoder',
'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer', 'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer',
#'LazyLinear', 'LazyConv1d', 'LazyConv2d', 'LazyConv3d', 'Flatten', 'Hardsigmoid'
#'LazyConvTranspose1d', 'LazyConvTranspose2d', 'LazyConvTranspose3d',
#'Unflatten', 'SiLU', 'TripletMarginWithDistanceLoss', 'ChannelShuffle',
'Flatten', 'Hardsigmoid', 'Hardswish'
] ]
if version_larger_equal(torch.__version__, '1.6.0'):
__all__.append('Hardswish')
if version_larger_equal(torch.__version__, '1.7.0'):
__all__.extend(['Unflatten', 'SiLU', 'TripletMarginWithDistanceLoss'])
#'LazyLinear', 'LazyConv1d', 'LazyConv2d', 'LazyConv3d',
#'LazyConvTranspose1d', 'LazyConvTranspose2d', 'LazyConvTranspose3d',
#'ChannelShuffle'
class LayerChoice(nn.Module): class LayerChoice(nn.Module):
def __init__(self, op_candidates, reduction=None, return_mask=False, key=None): def __init__(self, op_candidates, reduction=None, return_mask=False, key=None):
super(LayerChoice, self).__init__() super(LayerChoice, self).__init__()
self.candidate_ops = op_candidates self.candidate_ops = op_candidates
self.label = key self.label = key
self.key = key # deprecated, for backward compatibility
for i, module in enumerate(op_candidates): # deprecated, for backward compatibility
self.add_module(str(i), module)
if reduction or return_mask: if reduction or return_mask:
_logger.warning('input arguments `reduction` and `return_mask` are deprecated!') _logger.warning('input arguments `reduction` and `return_mask` are deprecated!')
...@@ -52,10 +63,12 @@ class InputChoice(nn.Module): ...@@ -52,10 +63,12 @@ class InputChoice(nn.Module):
def __init__(self, n_candidates=None, choose_from=None, n_chosen=1, def __init__(self, n_candidates=None, choose_from=None, n_chosen=1,
reduction="sum", return_mask=False, key=None): reduction="sum", return_mask=False, key=None):
super(InputChoice, self).__init__() super(InputChoice, self).__init__()
self.n_candidates = n_candidates
self.n_chosen = n_chosen self.n_chosen = n_chosen
self.reduction = reduction self.reduction = reduction
self.label = key self.label = key
if n_candidates or choose_from or return_mask: self.key = key # deprecated, for backward compatibility
if choose_from or return_mask:
_logger.warning('input arguments `n_candidates`, `choose_from` and `return_mask` are deprecated!') _logger.warning('input arguments `n_candidates`, `choose_from` and `return_mask` are deprecated!')
def forward(self, candidate_inputs: List[torch.Tensor]) -> torch.Tensor: def forward(self, candidate_inputs: List[torch.Tensor]) -> torch.Tensor:
...@@ -86,13 +99,31 @@ class Placeholder(nn.Module): ...@@ -86,13 +99,31 @@ class Placeholder(nn.Module):
class ChosenInputs(nn.Module): class ChosenInputs(nn.Module):
def __init__(self, chosen: int): """
"""
def __init__(self, chosen: List[int], reduction: str):
super().__init__() super().__init__()
self.chosen = chosen self.chosen = chosen
self.reduction = reduction
def forward(self, candidate_inputs): def forward(self, candidate_inputs):
# TODO: support multiple chosen inputs return self._tensor_reduction(self.reduction, [candidate_inputs[i] for i in self.chosen])
return candidate_inputs[self.chosen]
def _tensor_reduction(self, reduction_type, tensor_list):
if reduction_type == "none":
return tensor_list
if not tensor_list:
return None # empty. return None for now
if len(tensor_list) == 1:
return tensor_list[0]
if reduction_type == "sum":
return sum(tensor_list)
if reduction_type == "mean":
return sum(tensor_list) / len(tensor_list)
if reduction_type == "concat":
return torch.cat(tensor_list, dim=1)
raise ValueError("Unrecognized reduction policy: \"{}\"".format(reduction_type))
# the following are pytorch modules # the following are pytorch modules
...@@ -132,7 +163,6 @@ def wrap_module(original_class): ...@@ -132,7 +163,6 @@ def wrap_module(original_class):
return original_class return original_class
# TODO: support different versions of pytorch
Identity = wrap_module(nn.Identity) Identity = wrap_module(nn.Identity)
Linear = wrap_module(nn.Linear) Linear = wrap_module(nn.Linear)
Conv1d = wrap_module(nn.Conv1d) Conv1d = wrap_module(nn.Conv1d)
...@@ -236,6 +266,17 @@ TransformerDecoder = wrap_module(nn.TransformerDecoder) ...@@ -236,6 +266,17 @@ TransformerDecoder = wrap_module(nn.TransformerDecoder)
TransformerEncoderLayer = wrap_module(nn.TransformerEncoderLayer) TransformerEncoderLayer = wrap_module(nn.TransformerEncoderLayer)
TransformerDecoderLayer = wrap_module(nn.TransformerDecoderLayer) TransformerDecoderLayer = wrap_module(nn.TransformerDecoderLayer)
Transformer = wrap_module(nn.Transformer) Transformer = wrap_module(nn.Transformer)
Flatten = wrap_module(nn.Flatten)
Hardsigmoid = wrap_module(nn.Hardsigmoid)
if version_larger_equal(torch.__version__, '1.6.0'):
Hardswish = wrap_module(nn.Hardswish)
if version_larger_equal(torch.__version__, '1.7.0'):
SiLU = wrap_module(nn.SiLU)
Unflatten = wrap_module(nn.Unflatten)
TripletMarginWithDistanceLoss = wrap_module(nn.TripletMarginWithDistanceLoss)
#LazyLinear = wrap_module(nn.LazyLinear) #LazyLinear = wrap_module(nn.LazyLinear)
#LazyConv1d = wrap_module(nn.LazyConv1d) #LazyConv1d = wrap_module(nn.LazyConv1d)
#LazyConv2d = wrap_module(nn.LazyConv2d) #LazyConv2d = wrap_module(nn.LazyConv2d)
...@@ -243,10 +284,4 @@ Transformer = wrap_module(nn.Transformer) ...@@ -243,10 +284,4 @@ Transformer = wrap_module(nn.Transformer)
#LazyConvTranspose1d = wrap_module(nn.LazyConvTranspose1d) #LazyConvTranspose1d = wrap_module(nn.LazyConvTranspose1d)
#LazyConvTranspose2d = wrap_module(nn.LazyConvTranspose2d) #LazyConvTranspose2d = wrap_module(nn.LazyConvTranspose2d)
#LazyConvTranspose3d = wrap_module(nn.LazyConvTranspose3d) #LazyConvTranspose3d = wrap_module(nn.LazyConvTranspose3d)
Flatten = wrap_module(nn.Flatten)
#Unflatten = wrap_module(nn.Unflatten)
Hardsigmoid = wrap_module(nn.Hardsigmoid)
Hardswish = wrap_module(nn.Hardswish)
#SiLU = wrap_module(nn.SiLU)
#TripletMarginWithDistanceLoss = wrap_module(nn.TripletMarginWithDistanceLoss)
#ChannelShuffle = wrap_module(nn.ChannelShuffle) #ChannelShuffle = wrap_module(nn.ChannelShuffle)
\ No newline at end of file
from .tpe_strategy import TPEStrategy from .tpe_strategy import TPEStrategy
from .random_strategy import RandomStrategy
import logging
import random
import time
from .. import Sampler, submit_models, query_available_resources
from .strategy import BaseStrategy
_logger = logging.getLogger(__name__)
class RandomSampler(Sampler):
def choice(self, candidates, mutator, model, index):
return random.choice(candidates)
class RandomStrategy(BaseStrategy):
def __init__(self):
self.random_sampler = RandomSampler()
def run(self, base_model, applied_mutators):
_logger.info('stargety start...')
while True:
avail_resource = query_available_resources()
if avail_resource > 0:
model = base_model
_logger.info('apply mutators...')
_logger.info('mutators: %s', str(applied_mutators))
for mutator in applied_mutators:
mutator.bind_sampler(self.random_sampler)
model = mutator.apply(model)
# run models
submit_models(model)
else:
time.sleep(2)
import logging import logging
import time
from nni.algorithms.hpo.hyperopt_tuner import HyperoptTuner from nni.algorithms.hpo.hyperopt_tuner import HyperoptTuner
from .. import Sampler, submit_models, wait_models from .. import Sampler, submit_models, query_available_resources, is_stopped_exec
from .strategy import BaseStrategy from .strategy import BaseStrategy
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
...@@ -39,6 +40,7 @@ class TPEStrategy(BaseStrategy): ...@@ -39,6 +40,7 @@ class TPEStrategy(BaseStrategy):
def __init__(self): def __init__(self):
self.tpe_sampler = TPESampler() self.tpe_sampler = TPESampler()
self.model_id = 0 self.model_id = 0
self.running_models = {}
def run(self, base_model, applied_mutators): def run(self, base_model, applied_mutators):
sample_space = [] sample_space = []
...@@ -48,9 +50,10 @@ class TPEStrategy(BaseStrategy): ...@@ -48,9 +50,10 @@ class TPEStrategy(BaseStrategy):
sample_space.extend(recorded_candidates) sample_space.extend(recorded_candidates)
self.tpe_sampler.update_sample_space(sample_space) self.tpe_sampler.update_sample_space(sample_space)
try:
_logger.info('stargety start...') _logger.info('stargety start...')
while True: while True:
avail_resource = query_available_resources()
if avail_resource > 0:
model = base_model model = base_model
_logger.info('apply mutators...') _logger.info('apply mutators...')
_logger.info('mutators: %s', str(applied_mutators)) _logger.info('mutators: %s', str(applied_mutators))
...@@ -61,9 +64,18 @@ class TPEStrategy(BaseStrategy): ...@@ -61,9 +64,18 @@ class TPEStrategy(BaseStrategy):
model = mutator.apply(model) model = mutator.apply(model)
# run models # run models
submit_models(model) submit_models(model)
wait_models(model) self.running_models[self.model_id] = model
self.tpe_sampler.receive_result(self.model_id, model.metric)
self.model_id += 1 self.model_id += 1
_logger.info('Strategy says: %s', model.metric) else:
except Exception: time.sleep(2)
_logger.error(logging.exception('message'))
_logger.warning('num of running models: %d', len(self.running_models))
to_be_deleted = []
for _id, _model in self.running_models.items():
if is_stopped_exec(_model):
if _model.metric is not None:
self.tpe_sampler.receive_result(_id, _model.metric)
_logger.warning('tpe receive results: %d, %s', _id, _model.metric)
to_be_deleted.append(_id)
for _id in to_be_deleted:
del self.running_models[_id]
...@@ -6,6 +6,7 @@ from collections import OrderedDict ...@@ -6,6 +6,7 @@ from collections import OrderedDict
import numpy as np import numpy as np
import torch import torch
import nni.retiarii.nn.pytorch as nn
from nni.nas.pytorch.mutables import InputChoice, LayerChoice from nni.nas.pytorch.mutables import InputChoice, LayerChoice
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
...@@ -157,7 +158,7 @@ def replace_layer_choice(root_module, init_fn, modules=None): ...@@ -157,7 +158,7 @@ def replace_layer_choice(root_module, init_fn, modules=None):
List[Tuple[str, nn.Module]] List[Tuple[str, nn.Module]]
A list from layer choice keys (names) and replaced modules. A list from layer choice keys (names) and replaced modules.
""" """
return _replace_module_with_type(root_module, init_fn, LayerChoice, modules) return _replace_module_with_type(root_module, init_fn, (LayerChoice, nn.LayerChoice), modules)
def replace_input_choice(root_module, init_fn, modules=None): def replace_input_choice(root_module, init_fn, modules=None):
...@@ -178,4 +179,4 @@ def replace_input_choice(root_module, init_fn, modules=None): ...@@ -178,4 +179,4 @@ def replace_input_choice(root_module, init_fn, modules=None):
List[Tuple[str, nn.Module]] List[Tuple[str, nn.Module]]
A list from layer choice keys (names) and replaced modules. A list from layer choice keys (names) and replaced modules.
""" """
return _replace_module_with_type(root_module, init_fn, InputChoice, modules) return _replace_module_with_type(root_module, init_fn, (InputChoice, nn.InputChoice), modules)
...@@ -10,6 +10,11 @@ def import_(target: str, allow_none: bool = False) -> Any: ...@@ -10,6 +10,11 @@ def import_(target: str, allow_none: bool = False) -> Any:
module = __import__(path, globals(), locals(), [identifier]) module = __import__(path, globals(), locals(), [identifier])
return getattr(module, identifier) return getattr(module, identifier)
def version_larger_equal(a: str, b: str) -> bool:
# TODO: refactor later
a = a.split('+')[0]
b = b.split('+')[0]
return tuple(map(int, a.split('.'))) >= tuple(map(int, b.split('.')))
_records = {} _records = {}
...@@ -24,7 +29,7 @@ def add_record(key, value): ...@@ -24,7 +29,7 @@ def add_record(key, value):
""" """
global _records global _records
if _records is not None: if _records is not None:
assert key not in _records, '{} already in _records'.format(key) #assert key not in _records, '{} already in _records'.format(key)
_records[key] = value _records[key] = value
......
...@@ -55,7 +55,7 @@ class Node(nn.Module): ...@@ -55,7 +55,7 @@ class Node(nn.Module):
ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False) ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False)
])) ]))
self.drop_path = ops.DropPath() self.drop_path = ops.DropPath()
self.input_switch = nn.InputChoice(n_chosen=2) self.input_switch = nn.InputChoice(n_candidates=num_prev_nodes, n_chosen=2)
def forward(self, prev_nodes: List['Tensor']) -> 'Tensor': def forward(self, prev_nodes: List['Tensor']) -> 'Tensor':
#assert self.ops.__len__() == len(prev_nodes) #assert self.ops.__len__() == len(prev_nodes)
......
...@@ -5,7 +5,7 @@ import torch ...@@ -5,7 +5,7 @@ import torch
from pathlib import Path from pathlib import Path
from nni.retiarii.experiment import RetiariiExperiment, RetiariiExeConfig from nni.retiarii.experiment import RetiariiExperiment, RetiariiExeConfig
from nni.retiarii.strategies import TPEStrategy from nni.retiarii.strategies import TPEStrategy, RandomStrategy
from nni.retiarii.trainer import PyTorchImageClassificationTrainer from nni.retiarii.trainer import PyTorchImageClassificationTrainer
from darts_model import CNN from darts_model import CNN
...@@ -18,7 +18,8 @@ if __name__ == '__main__': ...@@ -18,7 +18,8 @@ if __name__ == '__main__':
optimizer_kwargs={"lr": 1e-3}, optimizer_kwargs={"lr": 1e-3},
trainer_kwargs={"max_epochs": 1}) trainer_kwargs={"max_epochs": 1})
simple_startegy = TPEStrategy() #simple_startegy = TPEStrategy()
simple_startegy = RandomStrategy()
exp = RetiariiExperiment(base_model, trainer, [], simple_startegy) exp = RetiariiExperiment(base_model, trainer, [], simple_startegy)
......
import json
import numpy as np
import os
import sys
import torch
import torch.nn as nn
from pathlib import Path
from torchvision import transforms
from torchvision.datasets import CIFAR10
from nni.retiarii.experiment import RetiariiExperiment, RetiariiExeConfig
from nni.retiarii.strategies import TPEStrategy
from nni.retiarii.trainer.pytorch import DartsTrainer
from darts_model import CNN
class Cutout(object):
    """Cutout augmentation: zero out one random square patch of an image
    tensor (C, H, W), in place, and return the same tensor."""

    def __init__(self, length):
        # Side length of the square patch to erase.
        self.length = length

    def __call__(self, img):
        h, w = img.size(1), img.size(2)
        # Draw the patch centre — vertical first, then horizontal, keeping
        # the original RNG call order.
        cy = np.random.randint(h)
        cx = np.random.randint(w)
        half = self.length // 2
        top, bottom = np.clip(cy - half, 0, h), np.clip(cy + half, 0, h)
        left, right = np.clip(cx - half, 0, w), np.clip(cx + half, 0, w)
        # Build a (H, W) float mask that is 0 inside the patch, 1 elsewhere,
        # then broadcast it across channels and apply in place.
        mask = np.ones((h, w), np.float32)
        mask[top:bottom, left:right] = 0.
        img *= torch.from_numpy(mask).expand_as(img)
        return img
def get_dataset(cls, cutout_length=0):
    """Build the (train, valid) dataset pair for *cls* with standard
    CIFAR-10 normalization; optionally append Cutout to the training
    transform when cutout_length > 0.

    Raises NotImplementedError for any dataset name other than "cifar10".
    """
    MEAN = [0.49139968, 0.48215827, 0.44653124]
    STD = [0.24703233, 0.24348505, 0.26158768]

    augment = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    normalize = [
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD),
    ]
    cutout = [Cutout(cutout_length)] if cutout_length > 0 else []

    train_transform = transforms.Compose(augment + normalize + cutout)
    valid_transform = transforms.Compose(normalize)

    # Guard clause instead of if/else: only CIFAR-10 is supported.
    if cls != "cifar10":
        raise NotImplementedError
    dataset_train = CIFAR10(root="./data", train=True, download=True, transform=train_transform)
    dataset_valid = CIFAR10(root="./data", train=False, download=True, transform=valid_transform)
    return dataset_train, dataset_valid
def accuracy(output, target, topk=(1,)):
    """Compute the precision@k for each k in *topk*.

    Parameters
    ----------
    output : torch.Tensor
        Logits/scores of shape (batch, num_classes).
    target : torch.Tensor
        Class indices of shape (batch,), or a one-hot matrix
        (batch, num_classes), which is reduced to indices via argmax.
    topk : tuple of int
        The k values to evaluate.

    Returns
    -------
    dict
        Maps "acc{k}" to the fraction of samples whose target is among
        the top-k predictions.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    # one-hot case: collapse to class indices
    if target.ndimension() > 1:
        target = target.max(1)[1]

    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = dict()
    for k in topk:
        # BUG FIX: `correct` is non-contiguous after pred.t(), so the original
        # `.view(-1)` raises "view size is not compatible ..." on modern
        # PyTorch; reshape(-1) copies when needed and is always safe.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
    return res
if __name__ == '__main__':
    # Base search space: DARTS-style CNN for 32x32 inputs, 3 input channels,
    # 16 initial channels, 10 classes, 8 layers (argument meanings presumed
    # from the CNN constructor — TODO confirm against darts_model.CNN).
    base_model = CNN(32, 3, 16, 10, 8)
    # CIFAR-10 train/valid split with the standard DARTS transforms.
    dataset_train, dataset_valid = get_dataset("cifar10")
    criterion = nn.CrossEntropyLoss()
    # SGD hyper-parameters follow the usual DARTS recipe (lr=0.025,
    # momentum=0.9, weight decay=3e-4) with cosine annealing over 50 epochs.
    optim = torch.optim.SGD(base_model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, 50, eta_min=0.001)
    # NOTE(review): lr_scheduler is created but never passed to the trainer —
    # verify whether DartsTrainer is expected to receive it.
    trainer = DartsTrainer(
        model=base_model,
        loss=criterion,
        # top-1 accuracy as the reported metric
        metrics=lambda output, target: accuracy(output, target, topk=(1,)),
        optimizer=optim,
        num_epochs=50,
        dataset=dataset_train,
        batch_size=32,
        log_frequency=10,
        unrolled=False
    )
    # One-shot experiment: no extra mutators/strategy arguments here.
    exp = RetiariiExperiment(base_model, trainer)
    exp.run()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment