"mmdet3d/datasets/vscode:/vscode.git/clone" did not exist on "bf4396ec1c3c0a78730bdcc850e346f261d4f1ba"
Unverified commit 4784cc6c, authored by liuzhe-lz and committed by GitHub

Merge pull request #3302 from microsoft/v2.0-merge

Merge branch v2.0 into master (no squash)
parents 25db55ca 349ead41
-from .interface import BaseTrainer
+from .interface import BaseTrainer, BaseOneShotTrainer
from .pytorch import PyTorchImageClassificationTrainer, PyTorchMultiModelTrainer

@@ -43,7 +43,7 @@ def get_default_transform(dataset: str) -> Any:
    return None

-@register_trainer()
+@register_trainer
class PyTorchImageClassificationTrainer(BaseTrainer):
    """
    Image classification trainer for PyTorch.

@@ -80,7 +80,7 @@ class PyTorchImageClassificationTrainer(BaseTrainer):
            Keyword arguments passed to trainer. Will be passed to Trainer class in future. Currently,
            only the key ``max_epochs`` is useful.
        """
-        super(PyTorchImageClassificationTrainer, self).__init__()
+        super().__init__()
        self._use_cuda = torch.cuda.is_available()
        self.model = model
        if self._use_cuda:

...
@@ -6,6 +6,7 @@ import logging

import torch
import torch.nn as nn
+import torch.nn.functional as F

from ..interface import BaseOneShotTrainer
from .utils import AverageMeterGroup, replace_layer_choice, replace_input_choice

@@ -17,13 +18,14 @@ _logger = logging.getLogger(__name__)

class DartsLayerChoice(nn.Module):
    def __init__(self, layer_choice):
        super(DartsLayerChoice, self).__init__()
+        self.name = layer_choice.key
        self.op_choices = nn.ModuleDict(layer_choice.named_children())
        self.alpha = nn.Parameter(torch.randn(len(self.op_choices)) * 1e-3)

    def forward(self, *args, **kwargs):
        op_results = torch.stack([op(*args, **kwargs) for op in self.op_choices.values()])
        alpha_shape = [-1] + [1] * (len(op_results.size()) - 1)
-        return torch.sum(op_results * self.alpha.view(*alpha_shape), 0)
+        return torch.sum(op_results * F.softmax(self.alpha, -1).view(*alpha_shape), 0)

    def parameters(self):
        for _, p in self.named_parameters():

@@ -42,13 +44,14 @@ class DartsLayerChoice(nn.Module):

class DartsInputChoice(nn.Module):
    def __init__(self, input_choice):
        super(DartsInputChoice, self).__init__()
+        self.name = input_choice.key
        self.alpha = nn.Parameter(torch.randn(input_choice.n_candidates) * 1e-3)
        self.n_chosen = input_choice.n_chosen or 1

    def forward(self, inputs):
        inputs = torch.stack(inputs)
        alpha_shape = [-1] + [1] * (len(inputs.size()) - 1)
-        return torch.sum(inputs * self.alpha.view(*alpha_shape), 0)
+        return torch.sum(inputs * F.softmax(self.alpha, -1).view(*alpha_shape), 0)

    def parameters(self):
        for _, p in self.named_parameters():

@@ -123,7 +126,15 @@ class DartsTrainer(BaseOneShotTrainer):
            module.to(self.device)

        self.model_optim = optimizer
-        self.ctrl_optim = torch.optim.Adam([m.alpha for _, m in self.nas_modules], arc_learning_rate, betas=(0.5, 0.999),
+        # use the same architecture weight for modules with duplicated names
+        ctrl_params = {}
+        for _, m in self.nas_modules:
+            if m.name in ctrl_params:
+                assert m.alpha.size() == ctrl_params[m.name].size(), 'Size of parameters with the same label should be same.'
+                m.alpha = ctrl_params[m.name]
+            else:
+                ctrl_params[m.name] = m.alpha
+        self.ctrl_optim = torch.optim.Adam(list(ctrl_params.values()), arc_learning_rate, betas=(0.5, 0.999),
                                           weight_decay=1.0E-3)
        self.unrolled = unrolled
        self.grad_clip = 5.

...
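Note on the two forward changes above: the architecture logits ``alpha`` are now passed through a softmax before mixing the candidate outputs, so the mixture weights form a proper distribution over candidates. A minimal, self-contained sketch of that mixture (all names here are illustrative, not part of the diff):

    import torch
    import torch.nn.functional as F

    def mix_candidates(candidate_outputs, alpha):
        # candidate_outputs: stacked op results, shape (n_candidates, batch, ...)
        # alpha: learnable architecture logits, shape (n_candidates,)
        weights = F.softmax(alpha, dim=-1)                       # normalize logits into mixture weights
        weights = weights.view(-1, *([1] * (candidate_outputs.dim() - 1)))
        return torch.sum(candidate_outputs * weights, dim=0)     # weighted sum over candidates

    outs = torch.stack([torch.randn(2, 4) for _ in range(3)])    # 3 candidate ops, batch of 2
    alpha = torch.randn(3) * 1e-3
    print(mix_candidates(outs, alpha).shape)                     # torch.Size([2, 4])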
@@ -157,6 +157,7 @@ class ProxylessTrainer(BaseOneShotTrainer):
            module.to(self.device)

        self.optimizer = optimizer
+        # we do not support deduplicate control parameters with same label (like DARTS) yet.
        self.ctrl_optim = torch.optim.Adam([m.alpha for _, m in self.nas_modules], arc_learning_rate,
                                           weight_decay=0, betas=(0, 0.999), eps=1e-8)
        self._init_dataloader()

...
@@ -6,6 +6,7 @@ from collections import OrderedDict

import numpy as np
import torch
+import nni.retiarii.nn.pytorch as nn
from nni.nas.pytorch.mutables import InputChoice, LayerChoice

_logger = logging.getLogger(__name__)

@@ -157,7 +158,7 @@ def replace_layer_choice(root_module, init_fn, modules=None):
    List[Tuple[str, nn.Module]]
        A list from layer choice keys (names) and replaced modules.
    """
-    return _replace_module_with_type(root_module, init_fn, LayerChoice, modules)
+    return _replace_module_with_type(root_module, init_fn, (LayerChoice, nn.LayerChoice), modules)

def replace_input_choice(root_module, init_fn, modules=None):

@@ -178,4 +179,4 @@ def replace_input_choice(root_module, init_fn, modules=None):
    List[Tuple[str, nn.Module]]
        A list from layer choice keys (names) and replaced modules.
    """
-    return _replace_module_with_type(root_module, init_fn, InputChoice, modules)
+    return _replace_module_with_type(root_module, init_fn, (InputChoice, nn.InputChoice), modules)
import inspect
+import warnings
from collections import defaultdict
from typing import Any

@@ -11,6 +12,13 @@ def import_(target: str, allow_none: bool = False) -> Any:
    return getattr(module, identifier)

+def version_larger_equal(a: str, b: str) -> bool:
+    # TODO: refactor later
+    a = a.split('+')[0]
+    b = b.split('+')[0]
+    return tuple(map(int, a.split('.'))) >= tuple(map(int, b.split('.')))

_records = {}

@@ -19,6 +27,11 @@ def get_records():
    return _records

+def clear_records():
+    global _records
+    _records = {}

def add_record(key, value):
    """
    """

@@ -28,69 +41,83 @@ def add_record(key, value):
    _records[key] = value

-def _register_module(original_class):
-    orig_init = original_class.__init__
-    argname_list = list(inspect.signature(original_class).parameters.keys())
-    # Make copy of original __init__, so we can call it without recursion
-    def __init__(self, *args, **kws):
-        full_args = {}
-        full_args.update(kws)
-        for i, arg in enumerate(args):
-            full_args[argname_list[i]] = arg
-        add_record(id(self), full_args)
-        orig_init(self, *args, **kws)  # Call the original __init__
-    original_class.__init__ = __init__  # Set the class' __init__ to the new one
-    return original_class
-
-def register_module():
-    """
-    Register a module.
-    """
-    # use it as a decorator: @register_module()
-    def _register(cls):
-        m = _register_module(
-            original_class=cls)
-        return m
-    return _register
-
-def _register_trainer(original_class):
-    orig_init = original_class.__init__
-    argname_list = list(inspect.signature(original_class).parameters.keys())
-    # Make copy of original __init__, so we can call it without recursion
-    full_class_name = original_class.__module__ + '.' + original_class.__name__
-    def __init__(self, *args, **kws):
-        full_args = {}
-        full_args.update(kws)
-        for i, arg in enumerate(args):
-            # TODO: support both pytorch and tensorflow
-            from .nn.pytorch import Module
-            if isinstance(args[i], Module):
-                # ignore the base model object
-                continue
-            full_args[argname_list[i]] = arg
-        add_record(id(self), {'modulename': full_class_name, 'args': full_args})
-        orig_init(self, *args, **kws)  # Call the original __init__
-    original_class.__init__ = __init__  # Set the class' __init__ to the new one
-    return original_class
-
-def register_trainer():
-    def _register(cls):
-        m = _register_trainer(
-            original_class=cls)
-        return m
-    return _register
+def del_record(key):
+    global _records
+    if _records is not None:
+        _records.pop(key, None)
+
+def _blackbox_cls(cls, module_name, register_format=None):
+    class wrapper(cls):
+        def __init__(self, *args, **kwargs):
+            argname_list = list(inspect.signature(cls).parameters.keys())
+            full_args = {}
+            full_args.update(kwargs)
+            assert len(args) <= len(argname_list), f'Length of {args} is greater than length of {argname_list}.'
+            for argname, value in zip(argname_list, args):
+                full_args[argname] = value
+            # eject un-serializable arguments
+            for k in list(full_args.keys()):
+                # The list is not complete and does not support nested cases.
+                if not isinstance(full_args[k], (int, float, str, dict, list, tuple)):
+                    if not (register_format == 'full' and k == 'model'):
+                        # no warning if it is base model in trainer
+                        warnings.warn(f'{cls} has un-serializable arguments {k} whose value is {full_args[k]}. \
This is not supported. You can ignore this warning if you are passing the model to trainer.')
+                    full_args.pop(k)
+            if register_format == 'args':
+                add_record(id(self), full_args)
+            elif register_format == 'full':
+                full_class_name = cls.__module__ + '.' + cls.__name__
+                add_record(id(self), {'modulename': full_class_name, 'args': full_args})
+            super().__init__(*args, **kwargs)
+
+        def __del__(self):
+            del_record(id(self))
+
+    # using module_name instead of cls.__module__ because it's more natural to see where the module gets wrapped
+    # instead of simply putting torch.nn or etc.
+    wrapper.__module__ = module_name
+    wrapper.__name__ = cls.__name__
+    wrapper.__qualname__ = cls.__qualname__
+    wrapper.__init__.__doc__ = cls.__init__.__doc__
+    return wrapper
+
+def blackbox(cls, *args, **kwargs):
+    """
+    To create an blackbox instance inline without decorator. For example,
+
+    .. code-block:: python
+
+        self.op = blackbox(MyCustomOp, hidden_units=128)
+    """
+    # get caller module name
+    frm = inspect.stack()[1]
+    module_name = inspect.getmodule(frm[0]).__name__
+    return _blackbox_cls(cls, module_name, 'args')(*args, **kwargs)
+
+def blackbox_module(cls):
+    """
+    Register a module. Use it as a decorator.
+    """
+    frm = inspect.stack()[1]
+    module_name = inspect.getmodule(frm[0]).__name__
+    return _blackbox_cls(cls, module_name, 'args')
+
+def register_trainer(cls):
+    """
+    Register a trainer. Use it as a decorator.
+    """
+    frm = inspect.stack()[1]
+    module_name = inspect.getmodule(frm[0]).__name__
+    return _blackbox_cls(cls, module_name, 'full')

_last_uid = defaultdict(int)

...
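The old register_module/register_trainer decorators patched __init__ in place; the new _blackbox_cls wrapper instead subclasses the target class and records the serializable constructor arguments so the instance can be re-created later from the record. A minimal, NNI-independent sketch of that argument-recording pattern (all names below are illustrative):

    import inspect

    _records = {}

    def recording_wrapper(cls):
        class Wrapper(cls):
            def __init__(self, *args, **kwargs):
                names = list(inspect.signature(cls).parameters.keys())
                captured = dict(kwargs)
                captured.update(zip(names, args))   # map positional args to parameter names
                _records[id(self)] = captured       # remember how this instance was constructed
                super().__init__(*args, **kwargs)
        Wrapper.__name__ = cls.__name__
        return Wrapper

    @recording_wrapper
    class Linearish:
        def __init__(self, in_features, out_features):
            self.shape = (in_features, out_features)

    layer = Linearish(4, out_features=8)
    print(_records[id(layer)])   # prints the captured constructor arguments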
@@ -12,6 +12,13 @@ import colorama

from .env_vars import dispatcher_env_vars, trial_env_vars

+handlers = {}
+
+log_format = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s'
+time_format = '%Y-%m-%d %H:%M:%S'
+formatter = Formatter(log_format, time_format)

def init_logger() -> None:
    """
    This function will (and should only) get invoked on the first time of importing nni (no matter which submodule).

@@ -37,6 +44,8 @@ def init_logger() -> None:
        _init_logger_standalone()

+    logging.getLogger('filelock').setLevel(logging.WARNING)

def init_logger_experiment() -> None:
    """

@@ -44,15 +53,19 @@ def init_logger_experiment() -> None:
    This function will get invoked after `init_logger()`.
    """
-    formatter.format = _colorful_format
+    colorful_formatter = Formatter(log_format, time_format)
+    colorful_formatter.format = _colorful_format
+    handlers['_default_'].setFormatter(colorful_formatter)

+def start_experiment_log(experiment_id: str, log_directory: Path, debug: bool) -> None:
+    log_path = _prepare_log_dir(log_directory) / 'dispatcher.log'
+    log_level = logging.DEBUG if debug else logging.INFO
+    _register_handler(FileHandler(log_path), log_level, experiment_id)

-time_format = '%Y-%m-%d %H:%M:%S'
+def stop_experiment_log(experiment_id: str) -> None:
+    if experiment_id in handlers:
+        logging.getLogger().removeHandler(handlers.pop(experiment_id))

-formatter = Formatter(
-    '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s',
-    time_format
-)

def _init_logger_dispatcher() -> None:
    log_level_map = {

@@ -66,26 +79,20 @@ def _init_logger_dispatcher() -> None:
    log_path = _prepare_log_dir(dispatcher_env_vars.NNI_LOG_DIRECTORY) / 'dispatcher.log'
    log_level = log_level_map.get(dispatcher_env_vars.NNI_LOG_LEVEL, logging.INFO)
-    _setup_root_logger(FileHandler(log_path), log_level)
+    _register_handler(FileHandler(log_path), log_level)

def _init_logger_trial() -> None:
    log_path = _prepare_log_dir(trial_env_vars.NNI_OUTPUT_DIR) / 'trial.log'
    log_file = open(log_path, 'w')
-    _setup_root_logger(StreamHandler(log_file), logging.INFO)
+    _register_handler(StreamHandler(log_file), logging.INFO)

    if trial_env_vars.NNI_PLATFORM == 'local':
        sys.stdout = _LogFileWrapper(log_file)

def _init_logger_standalone() -> None:
-    _setup_nni_logger(StreamHandler(sys.stdout), logging.INFO)
-    # Following line does not affect NNI loggers, but without this user's logger won't
-    # print log even it's level is set to INFO, so we do it for user's convenience.
-    # If this causes any issue in future, remove it and use `logging.info()` instead of
-    # `logging.getLogger('xxx').info()` in all examples.
-    logging.basicConfig()
+    _register_handler(StreamHandler(sys.stdout), logging.INFO)

def _prepare_log_dir(path: Optional[str]) -> Path:

@@ -95,20 +102,18 @@ def _prepare_log_dir(path: Optional[str]) -> Path:
    ret.mkdir(parents=True, exist_ok=True)
    return ret

-def _setup_root_logger(handler: Handler, level: int) -> None:
-    _setup_logger('', handler, level)
-
-def _setup_nni_logger(handler: Handler, level: int) -> None:
-    _setup_logger('nni', handler, level)
-
-def _setup_logger(name: str, handler: Handler, level: int) -> None:
+def _register_handler(handler: Handler, level: int, tag: str = '_default_') -> None:
+    assert tag not in handlers
+    handlers[tag] = handler
    handler.setFormatter(formatter)
-    logger = logging.getLogger(name)
+    logger = logging.getLogger()
    logger.addHandler(handler)
    logger.setLevel(level)
-    logger.propagate = False

def _colorful_format(record):
+    time = formatter.formatTime(record, time_format)
+    if not record.name.startswith('nni.'):
+        return '[{}] ({}) {}'.format(time, record.name, record.msg % record.args)
    if record.levelno >= logging.ERROR:
        color = colorama.Fore.RED
    elif record.levelno >= logging.WARNING:

@@ -118,7 +123,6 @@ def _colorful_format(record):
    else:
        color = colorama.Fore.BLUE
    msg = color + (record.msg % record.args) + colorama.Style.RESET_ALL
-    time = formatter.formatTime(record, time_format)
    if record.levelno < logging.INFO:
        return '[{}] {}:{} {}'.format(time, record.threadName, record.name, msg)
    else:

...
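The logging rewrite above replaces the per-name logger setup with a single module-level handler registry keyed by tag, so an experiment's file handler can be attached in start_experiment_log and detached again in stop_experiment_log. A rough sketch of that registry pattern outside NNI (names are illustrative):

    import logging
    from logging import FileHandler, Formatter

    handlers = {}
    formatter = Formatter('[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s',
                          '%Y-%m-%d %H:%M:%S')

    def register_handler(handler, level, tag='_default_'):
        assert tag not in handlers          # one handler per tag
        handlers[tag] = handler
        handler.setFormatter(formatter)
        root = logging.getLogger()
        root.addHandler(handler)
        root.setLevel(level)

    def unregister_handler(tag):
        if tag in handlers:                 # detach the tagged handler from the root logger
            logging.getLogger().removeHandler(handlers.pop(tag))

    register_handler(FileHandler('experiment_abc.log'), logging.INFO, tag='abc')
    logging.getLogger('demo').info('hello')
    unregister_handler('abc')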
@@ -9,7 +9,7 @@ if trial_env_vars.NNI_PLATFORM is None:
    from .standalone import *
elif trial_env_vars.NNI_PLATFORM == 'unittest':
    from .test import *
-elif trial_env_vars.NNI_PLATFORM in ('local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml', 'adl', 'heterogeneous'):
+elif trial_env_vars.NNI_PLATFORM in ('local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml', 'adl', 'hybrid'):
    from .local import *
else:
    raise RuntimeError('Unknown platform %s' % trial_env_vars.NNI_PLATFORM)
@@ -124,7 +124,7 @@ common_schema = {
    Optional('maxExecDuration'): And(Regex(r'^[1-9][0-9]*[s|m|h|d]$', error='ERROR: maxExecDuration format is [digit]{s,m,h,d}')),
    Optional('maxTrialNum'): setNumberRange('maxTrialNum', int, 1, 99999),
    'trainingServicePlatform': setChoice(
-        'trainingServicePlatform', 'remote', 'local', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml', 'adl', 'heterogeneous'),
+        'trainingServicePlatform', 'remote', 'local', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml', 'adl', 'hybrid'),
    Optional('searchSpacePath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'searchSpacePath'),
    Optional('multiPhase'): setType('multiPhase', bool),
    Optional('multiThread'): setType('multiThread', bool),

@@ -262,7 +262,7 @@ aml_config_schema = {
    }
}

-heterogeneous_trial_schema = {
+hybrid_trial_schema = {
    'trial': {
        'codeDir': setPathCheck('codeDir'),
        Optional('nniManagerNFSMountPath'): setPathCheck('nniManagerNFSMountPath'),

@@ -279,8 +279,8 @@ heterogeneous_trial_schema = {
    }
}

-heterogeneous_config_schema = {
-    'heterogeneousConfig': {
+hybrid_config_schema = {
+    'hybridConfig': {
        'trainingServicePlatforms': ['local', 'remote', 'pai', 'aml']
    }
}

@@ -461,7 +461,7 @@ training_service_schema_dict = {
    'frameworkcontroller': Schema({**common_schema, **frameworkcontroller_trial_schema, **frameworkcontroller_config_schema}),
    'aml': Schema({**common_schema, **aml_trial_schema, **aml_config_schema}),
    'dlts': Schema({**common_schema, **dlts_trial_schema, **dlts_config_schema}),
-    'heterogeneous': Schema({**common_schema, **heterogeneous_trial_schema, **heterogeneous_config_schema, **machine_list_schema,
+    'hybrid': Schema({**common_schema, **hybrid_trial_schema, **hybrid_config_schema, **machine_list_schema,
                             **pai_config_schema, **aml_config_schema, **remote_config_schema}),
}

@@ -479,7 +479,7 @@ class NNIConfigSchema:
        self.validate_pai_trial_conifg(experiment_config)
        self.validate_kubeflow_operators(experiment_config)
        self.validate_eth0_device(experiment_config)
-        self.validate_heterogeneous_platforms(experiment_config)
+        self.validate_hybrid_platforms(experiment_config)

    def validate_tuner_adivosr_assessor(self, experiment_config):
        if experiment_config.get('advisor'):

@@ -590,15 +590,15 @@ class NNIConfigSchema:
                and 'eth0' not in netifaces.interfaces():
            raise SchemaError('This machine does not contain eth0 network device, please set nniManagerIp in config file!')

-    def validate_heterogeneous_platforms(self, experiment_config):
+    def validate_hybrid_platforms(self, experiment_config):
        required_config_name_map = {
            'remote': 'machineList',
            'aml': 'amlConfig',
            'pai': 'paiConfig'
        }
-        if experiment_config.get('trainingServicePlatform') == 'heterogeneous':
-            for platform in experiment_config['heterogeneousConfig']['trainingServicePlatforms']:
+        if experiment_config.get('trainingServicePlatform') == 'hybrid':
+            for platform in experiment_config['hybridConfig']['trainingServicePlatforms']:
                config_name = required_config_name_map.get(platform)
                if config_name and not experiment_config.get(config_name):
-                    raise SchemaError('Need to set {0} for {1} in heterogeneous mode!'.format(config_name, platform))
+                    raise SchemaError('Need to set {0} for {1} in hybrid mode!'.format(config_name, platform))
\ No newline at end of file
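The renamed validate_hybrid_platforms check requires that every platform listed under hybridConfig.trainingServicePlatforms comes with its matching section in the experiment config. A small self-contained sketch of that rule, with a made-up config that passes it:

    required_config_name_map = {'remote': 'machineList', 'aml': 'amlConfig', 'pai': 'paiConfig'}

    def check_hybrid(experiment_config):
        if experiment_config.get('trainingServicePlatform') != 'hybrid':
            return
        for platform in experiment_config['hybridConfig']['trainingServicePlatforms']:
            config_name = required_config_name_map.get(platform)
            if config_name and not experiment_config.get(config_name):
                raise ValueError(f'Need to set {config_name} for {platform} in hybrid mode!')

    check_hybrid({
        'trainingServicePlatform': 'hybrid',
        'hybridConfig': {'trainingServicePlatforms': ['local', 'remote']},
        'machineList': [{'ip': '10.0.0.1', 'username': 'nni', 'port': 22}],   # hypothetical machine
    })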
@@ -17,7 +17,7 @@ from .launcher_utils import validate_all_content
from .rest_utils import rest_put, rest_post, check_rest_server, check_response
from .url_utils import cluster_metadata_url, experiment_url, get_local_urls
from .config_utils import Config, Experiments
-from .common_utils import get_yml_content, get_json_content, print_error, print_normal, \
+from .common_utils import get_yml_content, get_json_content, print_error, print_normal, print_warning, \
    detect_port, get_user
from .constants import NNICTL_HOME_DIR, ERROR_INFO, REST_TIME_OUT, EXPERIMENT_SUCCESS_INFO, LOG_HEADER

@@ -47,10 +47,10 @@ def start_rest_server(port, platform, mode, experiment_id, foreground=False, log
                    'You could use \'nnictl create --help\' to get help information' % port)
        exit(1)

-    if (platform != 'local') and detect_port(int(port) + 1):
-        print_error('PAI mode need an additional adjacent port %d, and the port %d is used by another process!\n' \
+    if (platform not in ['local', 'aml']) and detect_port(int(port) + 1):
+        print_error('%s mode need an additional adjacent port %d, and the port %d is used by another process!\n' \
                    'You could set another port to start experiment!\n' \
-                   'You could use \'nnictl create --help\' to get help information' % ((int(port) + 1), (int(port) + 1)))
+                   'You could use \'nnictl create --help\' to get help information' % (platform, (int(port) + 1), (int(port) + 1)))
        exit(1)

    print_normal('Starting restful server...')

@@ -300,23 +300,25 @@ def set_aml_config(experiment_config, port, config_file_name):
    #set trial_config
    return set_trial_config(experiment_config, port, config_file_name), err_message

-def set_heterogeneous_config(experiment_config, port, config_file_name):
-    '''set heterogeneous configuration'''
-    heterogeneous_config_data = dict()
-    heterogeneous_config_data['heterogeneous_config'] = experiment_config['heterogeneousConfig']
-    platform_list = experiment_config['heterogeneousConfig']['trainingServicePlatforms']
+def set_hybrid_config(experiment_config, port, config_file_name):
+    '''set hybrid configuration'''
+    hybrid_config_data = dict()
+    hybrid_config_data['hybrid_config'] = experiment_config['hybridConfig']
+    platform_list = experiment_config['hybridConfig']['trainingServicePlatforms']
    for platform in platform_list:
        if platform == 'aml':
-            heterogeneous_config_data['aml_config'] = experiment_config['amlConfig']
+            hybrid_config_data['aml_config'] = experiment_config['amlConfig']
        elif platform == 'remote':
            if experiment_config.get('remoteConfig'):
-                heterogeneous_config_data['remote_config'] = experiment_config['remoteConfig']
-            heterogeneous_config_data['machine_list'] = experiment_config['machineList']
+                hybrid_config_data['remote_config'] = experiment_config['remoteConfig']
+            hybrid_config_data['machine_list'] = experiment_config['machineList']
        elif platform == 'local' and experiment_config.get('localConfig'):
-            heterogeneous_config_data['local_config'] = experiment_config['localConfig']
+            hybrid_config_data['local_config'] = experiment_config['localConfig']
        elif platform == 'pai':
-            heterogeneous_config_data['pai_config'] = experiment_config['paiConfig']
-    response = rest_put(cluster_metadata_url(port), json.dumps(heterogeneous_config_data), REST_TIME_OUT)
+            hybrid_config_data['pai_config'] = experiment_config['paiConfig']
+    # It needs to connect all remote machines, set longer timeout here to wait for restful server connection response.
+    time_out = 60 if 'remote' in platform_list else REST_TIME_OUT
+    response = rest_put(cluster_metadata_url(port), json.dumps(hybrid_config_data), time_out)
    err_message = None
    if not response or not response.status_code == 200:
        if response is not None:

@@ -412,10 +414,10 @@ def set_experiment(experiment_config, mode, port, config_file_name):
            {'key': 'aml_config', 'value': experiment_config['amlConfig']})
        request_data['clusterMetaData'].append(
            {'key': 'trial_config', 'value': experiment_config['trial']})
-    elif experiment_config['trainingServicePlatform'] == 'heterogeneous':
+    elif experiment_config['trainingServicePlatform'] == 'hybrid':
        request_data['clusterMetaData'].append(
-            {'key': 'heterogeneous_config', 'value': experiment_config['heterogeneousConfig']})
-        platform_list = experiment_config['heterogeneousConfig']['trainingServicePlatforms']
+            {'key': 'hybrid_config', 'value': experiment_config['hybridConfig']})
+        platform_list = experiment_config['hybridConfig']['trainingServicePlatforms']
        request_dict = {
            'aml': {'key': 'aml_config', 'value': experiment_config.get('amlConfig')},
            'remote': {'key': 'machine_list', 'value': experiment_config.get('machineList')},

@@ -460,8 +462,8 @@ def set_platform_config(platform, experiment_config, port, config_file_name, res
        config_result, err_msg = set_dlts_config(experiment_config, port, config_file_name)
    elif platform == 'aml':
        config_result, err_msg = set_aml_config(experiment_config, port, config_file_name)
-    elif platform == 'heterogeneous':
-        config_result, err_msg = set_heterogeneous_config(experiment_config, port, config_file_name)
+    elif platform == 'hybrid':
+        config_result, err_msg = set_hybrid_config(experiment_config, port, config_file_name)
    else:
        raise Exception(ERROR_INFO % 'Unsupported platform!')
        exit(1)

@@ -509,6 +511,11 @@ def launch_experiment(args, experiment_config, mode, experiment_id):
    rest_process, start_time = start_rest_server(args.port, experiment_config['trainingServicePlatform'], \
                                                 mode, experiment_id, foreground, log_dir, log_level)
    nni_config.set_config('restServerPid', rest_process.pid)
+    # save experiment information
+    nnictl_experiment_config = Experiments()
+    nnictl_experiment_config.add_experiment(experiment_id, args.port, start_time,
+                                            experiment_config['trainingServicePlatform'],
+                                            experiment_config['experimentName'], pid=rest_process.pid, logDir=log_dir)
    # Deal with annotation
    if experiment_config.get('useAnnotation'):
        path = os.path.join(tempfile.gettempdir(), get_user(), 'nni', 'annotation')

@@ -546,11 +553,6 @@ def launch_experiment(args, experiment_config, mode, experiment_id):
    # start a new experiment
    print_normal('Starting experiment...')
-    # save experiment information
-    nnictl_experiment_config = Experiments()
-    nnictl_experiment_config.add_experiment(experiment_id, args.port, start_time,
-                                            experiment_config['trainingServicePlatform'],
-                                            experiment_config['experimentName'], pid=rest_process.pid, logDir=log_dir)
    # set debug configuration
    if mode != 'view' and experiment_config.get('debug') is None:
        experiment_config['debug'] = args.debug

@@ -567,7 +569,7 @@ def launch_experiment(args, experiment_config, mode, experiment_id):
            raise Exception(ERROR_INFO % 'Restful server stopped!')
        exit(1)
    if experiment_config.get('nniManagerIp'):
-        web_ui_url_list = ['{0}:{1}'.format(experiment_config['nniManagerIp'], str(args.port))]
+        web_ui_url_list = ['http://{0}:{1}'.format(experiment_config['nniManagerIp'], str(args.port))]
    else:
        web_ui_url_list = get_local_urls(args.port)
    nni_config.set_config('webuiUrl', web_ui_url_list)

@@ -592,24 +594,28 @@ def create_experiment(args):
        print_error('Please set correct config path!')
        exit(1)
    experiment_config = get_yml_content(config_path)
-    try:
-        config = ExperimentConfig(**experiment_config)
-        experiment_config = convert.to_v1_yaml(config)
-    except Exception:
-        pass
    try:
        validate_all_content(experiment_config, config_path)
-    except Exception as e:
-        print_error(e)
-        exit(1)
+    except Exception:
+        print_warning('Validation with V1 schema failed. Trying to convert from V2 format...')
+        try:
+            config = ExperimentConfig(**experiment_config)
+            experiment_config = convert.to_v1_yaml(config)
+        except Exception as e:
+            print_error(f'Conversion from v2 format failed: {repr(e)}')
+        try:
+            validate_all_content(experiment_config, config_path)
+        except Exception as e:
+            print_error(f'Config in v1 format validation failed. {repr(e)}')
+            exit(1)

    nni_config.set_config('experimentConfig', experiment_config)
    nni_config.set_config('restServerPort', args.port)
    try:
        launch_experiment(args, experiment_config, 'new', experiment_id)
    except Exception as exception:
-        nni_config = Config(experiment_id)
-        restServerPid = nni_config.get_config('restServerPid')
+        restServerPid = Experiments().get_all_experiments().get(experiment_id, {}).get('pid')
        if restServerPid:
            kill_command(restServerPid)
        print_error(exception)

@@ -641,8 +647,7 @@ def manage_stopped_experiment(args, mode):
    try:
        launch_experiment(args, experiment_config, mode, experiment_id)
    except Exception as exception:
-        nni_config = Config(experiment_id)
-        restServerPid = nni_config.get_config('restServerPid')
+        restServerPid = Experiments().get_all_experiments().get(experiment_id, {}).get('pid')
        if restServerPid:
            kill_command(restServerPid)
        print_error(exception)

...
@@ -105,7 +105,9 @@ def set_default_values(experiment_config):
        experiment_config['maxExecDuration'] = '999d'
    if experiment_config.get('maxTrialNum') is None:
        experiment_config['maxTrialNum'] = 99999
-    if experiment_config['trainingServicePlatform'] == 'remote':
+    if experiment_config['trainingServicePlatform'] == 'remote' or \
+            experiment_config['trainingServicePlatform'] == 'hybrid' and \
+            'remote' in experiment_config['hybridConfig']['trainingServicePlatforms']:
        for index in range(len(experiment_config['machineList'])):
            if experiment_config['machineList'][index].get('port') is None:
                experiment_config['machineList'][index]['port'] = 22

...
@@ -2,6 +2,7 @@
# Licensed under the MIT license.

import argparse
+import logging
import os
import pkg_resources
from colorama import init

@@ -32,6 +33,8 @@ def nni_info(*args):
        print('please run "nnictl {positional argument} --help" to see nnictl guidance')

def parse_args():
+    logging.getLogger().setLevel(logging.ERROR)
+
    '''Definite the arguments users need to follow and input'''
    parser = argparse.ArgumentParser(prog='nnictl', description='use nnictl command to control nni experiments')
    parser.add_argument('--version', '-v', action='store_true')

@@ -243,12 +246,9 @@ def parse_args():
    def show_messsage_for_nnictl_package(args):
        print_error('nnictl package command is replaced by nnictl algo, please run nnictl algo -h to show the usage')

-    parser_package_subparsers = subparsers.add_parser('package', help='control nni tuner and assessor packages').add_subparsers()
-    parser_package_subparsers.add_parser('install', help='install packages').set_defaults(func=show_messsage_for_nnictl_package)
-    parser_package_subparsers.add_parser('uninstall', help='uninstall packages').set_defaults(func=show_messsage_for_nnictl_package)
-    parser_package_subparsers.add_parser('show', help='show the information of packages').set_defaults(
-        func=show_messsage_for_nnictl_package)
-    parser_package_subparsers.add_parser('list', help='list installed packages').set_defaults(func=show_messsage_for_nnictl_package)
+    parser_package_subparsers = subparsers.add_parser('package', help='this argument is replaced by algo', prefix_chars='\n')
+    parser_package_subparsers.add_argument('args', nargs=argparse.REMAINDER)
+    parser_package_subparsers.set_defaults(func=show_messsage_for_nnictl_package)

    #parse tensorboard command
    parser_tensorboard = subparsers.add_parser('tensorboard', help='manage tensorboard')

...
@@ -50,11 +50,9 @@ def update_experiment():
    for key in experiment_dict.keys():
        if isinstance(experiment_dict[key], dict):
            if experiment_dict[key].get('status') != 'STOPPED':
-                nni_config = Config(key)
-                rest_pid = nni_config.get_config('restServerPid')
+                rest_pid = experiment_dict[key].get('pid')
                if not detect_process(rest_pid):
                    experiment_config.update_experiment(key, 'status', 'STOPPED')
-                    experiment_config.update_experiment(key, 'port', None)
                    continue

def check_experiment_id(args, update=True):

@@ -83,10 +81,10 @@ def check_experiment_id(args, update=True):
            experiment_information += EXPERIMENT_DETAIL_FORMAT % (key,
                experiment_dict[key].get('experimentName', 'N/A'),
                experiment_dict[key]['status'],
-               experiment_dict[key]['port'],
+               experiment_dict[key].get('port', 'N/A'),
                experiment_dict[key].get('platform'),
-               experiment_dict[key]['startTime'],
-               experiment_dict[key]['endTime'])
+               time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['startTime'] / 1000)) if isinstance(experiment_dict[key]['startTime'], int) else experiment_dict[key]['startTime'],
+               time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['endTime'] / 1000)) if isinstance(experiment_dict[key]['endTime'], int) else experiment_dict[key]['endTime'])
            print(EXPERIMENT_INFORMATION_FORMAT % experiment_information)
            exit(1)
        elif not running_experiment_list:

@@ -130,7 +128,7 @@ def parse_ids(args):
        return running_experiment_list
    if args.port is not None:
        for key in running_experiment_list:
-            if experiment_dict[key]['port'] == args.port:
+            if experiment_dict[key].get('port') == args.port:
                result_list.append(key)
        if args.id and result_list and args.id != result_list[0]:
            print_error('Experiment id and resful server port not match')

@@ -143,10 +141,10 @@ def parse_ids(args):
            experiment_information += EXPERIMENT_DETAIL_FORMAT % (key,
                experiment_dict[key].get('experimentName', 'N/A'),
                experiment_dict[key]['status'],
-               experiment_dict[key]['port'],
+               experiment_dict[key].get('port', 'N/A'),
                experiment_dict[key].get('platform'),
-               experiment_dict[key]['startTime'],
-               experiment_dict[key]['endTime'])
+               time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['startTime'] / 1000)) if isinstance(experiment_dict[key]['startTime'], int) else experiment_dict[key]['startTime'],
+               time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['endTime'] / 1000)) if isinstance(experiment_dict[key]['endTime'], int) else experiment_dict[key]['endTime'])
            print(EXPERIMENT_INFORMATION_FORMAT % experiment_information)
            exit(1)
        else:

@@ -186,7 +184,7 @@ def get_experiment_port(args):
        exit(1)
    experiment_config = Experiments()
    experiment_dict = experiment_config.get_all_experiments()
-    return experiment_dict[experiment_id]['port']
+    return experiment_dict[experiment_id].get('port')

def convert_time_stamp_to_date(content):
    '''Convert time stamp to date time format'''

@@ -202,8 +200,9 @@ def convert_time_stamp_to_date(content):

def check_rest(args):
    '''check if restful server is running'''
-    nni_config = Config(get_config_filename(args))
-    rest_port = nni_config.get_config('restServerPort')
+    experiment_config = Experiments()
+    experiment_dict = experiment_config.get_all_experiments()
+    rest_port = experiment_dict.get(get_config_filename(args)).get('port')
    running, _ = check_rest_server_quick(rest_port)
    if running:
        print_normal('Restful server is running...')

@@ -220,18 +219,19 @@ def stop_experiment(args):
    if experiment_id_list:
        for experiment_id in experiment_id_list:
            print_normal('Stopping experiment %s' % experiment_id)
-            nni_config = Config(experiment_id)
-            rest_pid = nni_config.get_config('restServerPid')
+            experiment_config = Experiments()
+            experiment_dict = experiment_config.get_all_experiments()
+            rest_pid = experiment_dict.get(experiment_id).get('pid')
            if rest_pid:
                kill_command(rest_pid)
-                tensorboard_pid_list = nni_config.get_config('tensorboardPidList')
+                tensorboard_pid_list = experiment_dict.get(experiment_id).get('tensorboardPidList')
                if tensorboard_pid_list:
                    for tensorboard_pid in tensorboard_pid_list:
                        try:
                            kill_command(tensorboard_pid)
                        except Exception as exception:
                            print_error(exception)
-                    nni_config.set_config('tensorboardPidList', [])
+                    experiment_config.update_experiment(experiment_id, 'tensorboardPidList', [])
            print_normal('Stop experiment success.')

def trial_ls(args):

@@ -250,9 +250,10 @@ def trial_ls(args):
    if args.head and args.tail:
        print_error('Head and tail cannot be set at the same time.')
        return
-    nni_config = Config(get_config_filename(args))
-    rest_port = nni_config.get_config('restServerPort')
-    rest_pid = nni_config.get_config('restServerPid')
+    experiment_config = Experiments()
+    experiment_dict = experiment_config.get_all_experiments()
+    rest_port = experiment_dict.get(get_config_filename(args)).get('port')
+    rest_pid = experiment_dict.get(get_config_filename(args)).get('pid')
    if not detect_process(rest_pid):
        print_error('Experiment is not running...')
        return

@@ -281,9 +282,10 @@ def trial_ls(args):

def trial_kill(args):
    '''List trial'''
-    nni_config = Config(get_config_filename(args))
-    rest_port = nni_config.get_config('restServerPort')
-    rest_pid = nni_config.get_config('restServerPid')
+    experiment_config = Experiments()
+    experiment_dict = experiment_config.get_all_experiments()
+    rest_port = experiment_dict.get(get_config_filename(args)).get('port')
+    rest_pid = experiment_dict.get(get_config_filename(args)).get('pid')
    if not detect_process(rest_pid):
        print_error('Experiment is not running...')
        return

@@ -312,9 +314,10 @@ def trial_codegen(args):

def list_experiment(args):
    '''Get experiment information'''
-    nni_config = Config(get_config_filename(args))
-    rest_port = nni_config.get_config('restServerPort')
-    rest_pid = nni_config.get_config('restServerPid')
+    experiment_config = Experiments()
+    experiment_dict = experiment_config.get_all_experiments()
+    rest_port = experiment_dict.get(get_config_filename(args)).get('port')
+    rest_pid = experiment_dict.get(get_config_filename(args)).get('pid')
    if not detect_process(rest_pid):
        print_error('Experiment is not running...')
        return

@@ -333,8 +336,9 @@ def list_experiment(args):

def experiment_status(args):
    '''Show the status of experiment'''
-    nni_config = Config(get_config_filename(args))
-    rest_port = nni_config.get_config('restServerPort')
+    experiment_config = Experiments()
+    experiment_dict = experiment_config.get_all_experiments()
+    rest_port = experiment_dict.get(get_config_filename(args)).get('port')
    result, response = check_rest_server_quick(rest_port)
    if not result:
        print_normal('Restful server is not running...')

@@ -620,12 +624,12 @@ def platform_clean(args):
            break
    if platform == 'remote':
        machine_list = config_content.get('machineList')
-        remote_clean(machine_list, None)
+        remote_clean(machine_list)
    elif platform == 'pai':
        host = config_content.get('paiConfig').get('host')
        user_name = config_content.get('paiConfig').get('userName')
        output_dir = config_content.get('trial').get('outputDir')
-        hdfs_clean(host, user_name, output_dir, None)
+        hdfs_clean(host, user_name, output_dir)
    print_normal('Done.')

def experiment_list(args):

@@ -651,7 +655,7 @@ def experiment_list(args):
        experiment_information += EXPERIMENT_DETAIL_FORMAT % (key,
            experiment_dict[key].get('experimentName', 'N/A'),
            experiment_dict[key]['status'],
-           experiment_dict[key]['port'],
+           experiment_dict[key].get('port', 'N/A'),
            experiment_dict[key].get('platform'),
            time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['startTime'] / 1000)) if isinstance(experiment_dict[key]['startTime'], int) else experiment_dict[key]['startTime'],
            time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['endTime'] / 1000)) if isinstance(experiment_dict[key]['endTime'], int) else experiment_dict[key]['endTime'])

@@ -752,9 +756,10 @@ def export_trials_data(args):
            groupby.setdefault(content['trialJobId'], []).append(json.loads(content['data']))
        return groupby

-    nni_config = Config(get_config_filename(args))
-    rest_port = nni_config.get_config('restServerPort')
-    rest_pid = nni_config.get_config('restServerPid')
+    experiment_config = Experiments()
+    experiment_dict = experiment_config.get_all_experiments()
+    rest_port = experiment_dict.get(get_config_filename(args)).get('port')
+    rest_pid = experiment_dict.get(get_config_filename(args)).get('pid')
    if not detect_process(rest_pid):
        print_error('Experiment is not running...')

...
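The experiment listings above now render startTime/endTime stored as epoch milliseconds into human-readable local time. For reference, the conversion used is simply (with a made-up timestamp):

    import time
    start_time_ms = 1612137600000                   # hypothetical epoch timestamp in milliseconds
    print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time_ms / 1000)))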
@@ -70,7 +70,7 @@ def format_tensorboard_log_path(path_list):
        new_path_list.append('name%d:%s' % (index + 1, value))
    return ','.join(new_path_list)

-def start_tensorboard_process(args, nni_config, path_list, temp_nni_path):
+def start_tensorboard_process(args, experiment_id, path_list, temp_nni_path):
    '''call cmds to start tensorboard process in local machine'''
    if detect_port(args.port):
        print_error('Port %s is used by another process, please reset port!' % str(args.port))

@@ -83,20 +83,19 @@ def start_tensorboard_process(args, nni_config, path_list, temp_nni_path):
    url_list = get_local_urls(args.port)
    print_green('Start tensorboard success!')
    print_normal('Tensorboard urls: ' + ' '.join(url_list))
-    tensorboard_process_pid_list = nni_config.get_config('tensorboardPidList')
+    experiment_config = Experiments()
+    tensorboard_process_pid_list = experiment_config.get_all_experiments().get(experiment_id).get('tensorboardPidList')
    if tensorboard_process_pid_list is None:
        tensorboard_process_pid_list = [tensorboard_process.pid]
    else:
        tensorboard_process_pid_list.append(tensorboard_process.pid)
-    nni_config.set_config('tensorboardPidList', tensorboard_process_pid_list)
+    experiment_config.update_experiment(experiment_id, 'tensorboardPidList', tensorboard_process_pid_list)

def stop_tensorboard(args):
    '''stop tensorboard'''
    experiment_id = check_experiment_id(args)
    experiment_config = Experiments()
-    experiment_dict = experiment_config.get_all_experiments()
-    nni_config = Config(experiment_id)
-    tensorboard_pid_list = nni_config.get_config('tensorboardPidList')
+    tensorboard_pid_list = experiment_config.get_all_experiments().get(experiment_id).get('tensorboardPidList')
    if tensorboard_pid_list:
        for tensorboard_pid in tensorboard_pid_list:
            try:

@@ -104,7 +103,7 @@ def stop_tensorboard(args):
                call(cmds)
            except Exception as exception:
                print_error(exception)
-        nni_config.set_config('tensorboardPidList', [])
+        experiment_config.update_experiment(experiment_id, 'tensorboardPidList', [])
        print_normal('Stop tensorboard success!')
    else:
        print_error('No tensorboard configuration!')

@@ -164,4 +163,4 @@ def start_tensorboard(args):
    os.makedirs(temp_nni_path, exist_ok=True)

    path_list = get_path_list(args, nni_config, trial_content, temp_nni_path)
-    start_tensorboard_process(args, nni_config, path_list, temp_nni_path)
+    start_tensorboard_process(args, experiment_id, path_list, temp_nni_path)
\ No newline at end of file
@@ -5,7 +5,7 @@ import json
import os

from .rest_utils import rest_put, rest_post, rest_get, check_rest_server_quick, check_response
from .url_utils import experiment_url, import_data_url
-from .config_utils import Config
+from .config_utils import Config, Experiments
from .common_utils import get_json_content, print_normal, print_error, print_warning
from .nnictl_utils import get_experiment_port, get_config_filename, detect_process
from .launcher_utils import parse_time

@@ -58,8 +58,9 @@ def get_query_type(key):

def update_experiment_profile(args, key, value):
    '''call restful server to update experiment profile'''
-    nni_config = Config(get_config_filename(args))
-    rest_port = nni_config.get_config('restServerPort')
+    experiment_config = Experiments()
+    experiment_dict = experiment_config.get_all_experiments()
+    rest_port = experiment_dict.get(get_config_filename(args)).get('port')
    running, _ = check_rest_server_quick(rest_port)
    if running:
        response = rest_get(experiment_url(rest_port), REST_TIME_OUT)

...
@@ -4,27 +4,26 @@
 jobs:
 - job: ubuntu_latest
   pool:
-    # FIXME: In ubuntu-20.04 Python interpreter crashed during SMAC UT
-    vmImage: ubuntu-18.04
+    vmImage: ubuntu-latest
   # This platform tests lint and doc first.
   steps:
   - task: UsePythonVersion@0
     inputs:
-      versionSpec: 3.6
+      versionSpec: 3.8
     displayName: Configure Python version
   - script: |
       set -e
-      python3 -m pip install --upgrade pip setuptools
-      python3 -m pip install pytest coverage
-      python3 -m pip install pylint flake8
+      python -m pip install --upgrade pip setuptools
+      python -m pip install pytest coverage
+      python -m pip install pylint flake8
       echo "##vso[task.setvariable variable=PATH]${HOME}/.local/bin:${PATH}"
     displayName: Install Python tools
   - script: |
-      python3 setup.py develop
+      python setup.py develop
     displayName: Install NNI
   - script: |
@@ -35,24 +34,28 @@ jobs:
       yarn eslint
     displayName: ESLint
+  # FIXME: temporarily fixed to pytorch 1.6 as 1.7 won't work with compression
   - script: |
       set -e
       sudo apt-get install -y pandoc
-      python3 -m pip install --upgrade pygments
-      python3 -m pip install --upgrade torch>=1.7.0+cpu torchvision>=0.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
-      python3 -m pip install --upgrade tensorflow
-      python3 -m pip install --upgrade gym onnx peewee thop graphviz
-      python3 -m pip install sphinx==1.8.3 sphinx-argparse==0.2.5 sphinx-markdown-tables==0.0.9 sphinx-rtd-theme==0.4.2 sphinxcontrib-websupport==1.1.0 recommonmark==0.5.0 nbsphinx
-      sudo apt-get install swig -y
-      python3 -m pip install -e .[SMAC,BOHB]
+      python -m pip install --upgrade pygments
+      python -m pip install "torch==1.6.0+cpu" "torchvision==0.7.0+cpu" -f https://download.pytorch.org/whl/torch_stable.html
+      python -m pip install tensorflow
+      python -m pip install gym onnx peewee thop graphviz
+      python -m pip install sphinx==3.3.1 sphinx-argparse==0.2.5 sphinx-rtd-theme==0.4.2 sphinxcontrib-websupport==1.1.0 nbsphinx
+      sudo apt-get remove swig -y
+      sudo apt-get install swig3.0 -y
+      sudo ln -s /usr/bin/swig3.0 /usr/bin/swig
+      python -m pip install -e .[SMAC,BOHB]
     displayName: Install extra dependencies
   - script: |
       set -e
-      python3 -m pylint --rcfile pylintrc nni
-      python3 -m flake8 nni --count --select=E9,F63,F72,F82 --show-source --statistics
+      python -m pylint --rcfile pylintrc nni
+      python -m flake8 nni --count --select=E9,F63,F72,F82 --show-source --statistics
       EXCLUDES=examples/trials/mnist-nas/*/mnist*.py,examples/trials/nas_cifar10/src/cifar10/general_child.py
-      python3 -m flake8 examples --count --exclude=$EXCLUDES --select=E9,F63,F72,F82 --show-source --statistics
+      python -m flake8 examples --count --exclude=$EXCLUDES --select=E9,F63,F72,F82 --show-source --statistics
     displayName: pylint and flake8
   - script: |
@@ -61,10 +64,14 @@ jobs:
     displayName: Check Sphinx documentation
   - script: |
+      set -e
       cd test
-      python3 -m pytest ut --ignore=ut/sdk/test_pruners.py --ignore=ut/sdk/test_compressor_tf.py
-      python3 -m pytest ut/sdk/test_pruners.py
-      python3 -m pytest ut/sdk/test_compressor_tf.py
+      python -m pytest ut --ignore=ut/sdk/test_pruners.py \
+        --ignore=ut/sdk/test_compressor_tf.py \
+        --ignore=ut/sdk/test_compressor_torch.py
+      python -m pytest ut/sdk/test_pruners.py
+      python -m pytest ut/sdk/test_compressor_tf.py
+      python -m pytest ut/sdk/test_compressor_torch.py
     displayName: Python unit test
   - script: |
@@ -77,7 +84,7 @@ jobs:
   - script: |
       cd test
-      python3 nni_test/nnitest/run_tests.py --config config/pr_tests.yml
+      python nni_test/nnitest/run_tests.py --config config/pr_tests.yml
     displayName: Simple integration test
......
trigger: none
pr: none
schedules:
- cron: 0 16 * * *
  branches:
    include: [ master ]

jobs:
- job: adl
  pool: NNI CI KUBE CLI
  timeoutInMinutes: 120
  steps:
  - script: |
      export NNI_RELEASE=999.$(date -u +%Y%m%d%H%M%S)
      echo "##vso[task.setvariable variable=PATH]${PATH}:${HOME}/.local/bin"
      echo "##vso[task.setvariable variable=NNI_RELEASE]${NNI_RELEASE}"
      echo "Working directory: ${PWD}"
      echo "NNI version: ${NNI_RELEASE}"
      echo "Build docker image: $(build_docker_image)"
      python3 -m pip install --upgrade pip setuptools
    displayName: Prepare
  - script: |
      set -e
      python3 setup.py build_ts
      python3 setup.py bdist_wheel -p manylinux1_x86_64
      python3 -m pip install dist/nni-${NNI_RELEASE}-py3-none-manylinux1_x86_64.whl[SMAC,BOHB]
    displayName: Build and install NNI
  - script: |
      set -e
      cd examples/tuners/customized_tuner
      python3 setup.py develop --user
      nnictl algo register --meta meta_file.yml
    displayName: Install customized tuner
  - script: |
      set -e
      docker login -u nnidev -p $(docker_hub_password)
      sed -i '$a RUN python3 -m pip install adaptdl tensorboard' Dockerfile
      sed -i '$a COPY examples /examples' Dockerfile
      sed -i '$a COPY test /test' Dockerfile
      echo '## Build docker image ##'
      docker build --build-arg NNI_RELEASE=${NNI_RELEASE} -t nnidev/nni-nightly .
      echo '## Upload docker image ##'
      docker push nnidev/nni-nightly
    condition: eq(variables['build_docker_image'], 'true')
    displayName: Build and upload docker image
  - script: |
      set -e
      cd test
      python3 nni_test/nnitest/generate_ts_config.py \
        --ts adl \
        --nni_docker_image nnidev/nni-nightly \
        --checkpoint_storage_class $(checkpoint_storage_class) \
        --checkpoint_storage_size $(checkpoint_storage_size) \
        --nni_manager_ip $(nni_manager_ip)
      python3 nni_test/nnitest/run_tests.py --config config/integration_tests.yml --ts adl
    displayName: Integration test
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
+trigger: none
+pr: none
 jobs:
 - job: validate_version_number
   pool:
@@ -13,9 +16,11 @@ jobs:
     displayName: Configure Python version
   - script: |
+      echo $(build_type)
+      echo $(NNI_RELEASE)
       export BRANCH_TAG=`git describe --tags --abbrev=0`
       echo $BRANCH_TAG
-      if [[ $BRANCH_TAG = v$(NNI_RELEASE) && $(NNI_RELEASE) =~ ^v[0-9](.[0-9])+$ ]]; then
+      if [[ $BRANCH_TAG == v$(NNI_RELEASE) && $(NNI_RELEASE) =~ ^[0-9](.[0-9])+$ ]]; then
         echo 'Build version match branch tag'
       else
         echo 'Build version does not match branch tag'
@@ -25,15 +30,16 @@ jobs:
     displayName: Validate release version number and branch tag
   - script: |
+      echo $(build_type)
       echo $(NNI_RELEASE)
-      if [[ $(NNI_RELEASE) =~ ^[0-9](.[0-9])+a[0-9]$ ]]; then
+      if [[ $(NNI_RELEASE) =~ ^[0-9](.[0-9])+(a|b|rc)[0-9]$ ]]; then
         echo 'Valid prerelease version $(NNI_RELEASE)'
         echo `git describe --tags --abbrev=0`
       else
         echo 'Invalid build version $(NNI_RELEASE)'
         exit 1
       fi
-    condition: ne( variables['build_type'], 'rerelease' )
+    condition: ne( variables['build_type'], 'release' )
     displayName: Validate prerelease version number
 - job: linux
@@ -49,22 +55,22 @@ jobs:
     displayName: Configure Python version
   - script: |
-      python -m pip install --upgrade pip setuptools twine
+      python -m pip install --upgrade pip setuptools wheel twine
       python test/vso_tools/build_wheel.py $(NNI_RELEASE)
-      if [ $(build_type) = 'release' ]
-      echo 'uploading release package to pypi...'
+    displayName: Build wheel
+  - script: |
+      if [[ $(build_type) == 'release' || $(build_type) == 'rc' ]]; then
+        echo 'uploading to pypi...'
       python -m twine upload -u nni -p $(pypi_password) dist/*
-      then
       else
-      echo 'uploading prerelease package to testpypi...'
+        echo 'uploading to testpypi...'
       python -m twine upload -u nni -p $(pypi_password) --repository-url https://test.pypi.org/legacy/ dist/*
       fi
-    displayName: Build and upload wheel
+    displayName: Upload wheel
   - script: |
-      if [ $(build_type) = 'release' ]
-      then
+      if [[ $(build_type) == 'release' || $(build_type) == 'rc' ]]; then
       docker login -u msranni -p $(docker_hub_password)
       export IMAGE_NAME=msranni/nni
       else
@@ -74,9 +80,11 @@ jobs:
       echo "## Building ${IMAGE_NAME}:$(NNI_RELEASE) ##"
       docker build --build-arg NNI_RELEASE=$(NNI_RELEASE) -t ${IMAGE_NAME} .
-      docker tag ${IMAGE_NAME} ${IMAGE_NAME}:$(NNI_RELEASE)
-      docker push ${IMAGE_NAME}
-      docker push ${IMAGE_NAME}:$(NNI_RELEASE)
+      docker tag ${IMAGE_NAME} ${IMAGE_NAME}:v$(NNI_RELEASE)
+      docker push ${IMAGE_NAME}:v$(NNI_RELEASE)
+      if [[ $(build_type) != 'rc' ]]; then
+        docker push ${IMAGE_NAME}
+      fi
     displayName: Build and upload docker image
 - job: macos
@@ -92,18 +100,19 @@ jobs:
     displayName: Configure Python version
   - script: |
-      python -m pip install --upgrade pip setuptools twine
+      python -m pip install --upgrade pip setuptools wheel twine
       python test/vso_tools/build_wheel.py $(NNI_RELEASE)
-      if [ $(build_type) = 'release' ]
+    displayName: Build wheel
+  - script: |
+      if [[ $(build_type) == 'release' || $(build_type) == 'rc' ]]; then
       echo '## uploading to pypi ##'
       python -m twine upload -u nni -p $(pypi_password) dist/*
-      then
       else
      echo '## uploading to testpypi ##'
       python -m twine upload -u nni -p $(pypi_password) --repository-url https://test.pypi.org/legacy/ dist/*
       fi
-    displayName: Build and upload wheel
+    displayName: Upload wheel
 - job: windows
   dependsOn: validate_version_number
@@ -118,15 +127,16 @@ jobs:
     displayName: Configure Python version
   - powershell: |
-      python -m pip install --upgrade pip setuptools twine
+      python -m pip install --upgrade pip setuptools wheel twine
       python test/vso_tools/build_wheel.py $(NNI_RELEASE)
-      if($env:BUILD_TYPE -eq 'release'){
+    displayName: Build wheel
+  - powershell: |
+      if ($env:BUILD_TYPE -eq 'release' -Or $env:BUILD_TYPE -eq 'rc') {
       Write-Host '## uploading to pypi ##'
       python -m twine upload -u nni -p $(pypi_password) dist/*
-      }
-      else{
+      } else {
       Write-Host '## uploading to testpypi ##'
       python -m twine upload -u nni -p $(pypi_password) --repository-url https://test.pypi.org/legacy/ dist/*
       }
-    displayName: Build and upload wheel
+    displayName: Upload wheel
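One detail worth calling out in the validate_version_number job above: the branch-tag check now compares against v$(NNI_RELEASE) while the version string itself drops the leading 'v', and the prerelease pattern accepts a/b/rc suffixes instead of alpha only. A quick illustration of what the two bash ERE patterns accept, rewritten in Python purely as a sanity check (the sample version strings are made up):

    import re

    # Same patterns the pipeline feeds to bash's [[ =~ ]] (POSIX ERE), used here only
    # to illustrate what they accept. Note that '.' is unescaped, so it matches any
    # character, not just a literal dot.
    release_pat = re.compile(r'^[0-9](.[0-9])+$')
    prerelease_pat = re.compile(r'^[0-9](.[0-9])+(a|b|rc)[0-9]$')

    assert release_pat.match('2.0')           # plain release version
    assert prerelease_pat.match('2.0a1')      # alpha, accepted before and after this change
    assert prerelease_pat.match('2.0rc1')     # release candidate, newly accepted
    assert not prerelease_pat.match('2.0')    # a plain release is not a prerelease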
@@ -69,7 +69,6 @@ dependencies = [
     'PythonWebHDFS',
     'colorama',
     'scikit-learn>=0.23.2',
-    'pkginfo',
     'websockets',
     'filelock',
     'prettytable',
@@ -112,11 +111,8 @@ def _setup():
         python_requires = '>=3.6',
         install_requires = dependencies,
         extras_require = {
-            'SMAC': [
-                'ConfigSpaceNNI @ git+https://github.com/QuanluZhang/ConfigSpace.git',
-                'smac @ git+https://github.com/QuanluZhang/SMAC3.git'
-            ],
-            'BOHB': ['ConfigSpace==0.4.7', 'statsmodels==0.10.0'],
+            'SMAC': ['ConfigSpaceNNI', 'smac4nni'],
+            'BOHB': ['ConfigSpace==0.4.7', 'statsmodels==0.12.0'],
             'PPOTuner': ['enum34', 'gym']
         },
         setup_requires = ['requests'],
@@ -189,6 +185,7 @@ class Build(build):
             sys.exit('Please set environment variable "NNI_RELEASE=<release_version>"')
         if os.path.islink('nni_node/main.js'):
             sys.exit('A development build already exists. Please uninstall NNI and run "python3 setup.py clean --all".')
+        open('nni/version.py', 'w').write(f"__version__ = '{release}'")
         super().run()

 class Develop(develop):
@@ -212,6 +209,7 @@ class Develop(develop):
         super().finalize_options()

     def run(self):
+        open('nni/version.py', 'w').write("__version__ = '999.dev0'")
         if not self.skip_ts:
             setup_ts.build(release=None)
         super().run()
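The two new open('nni/version.py', ...) lines above generate the version module at build or develop time. A rough sketch of the observable effect, assuming nni/__init__.py re-exports __version__ from the generated file (an assumption, not shown in this diff) and using example version strings:

    # Assumed behaviour for illustration only; the version strings are examples.
    # After a release build (NNI_RELEASE=2.0), nni/version.py contains:  __version__ = '2.0'
    # After `python setup.py develop`, it contains:                      __version__ = '999.dev0'
    import nni
    print(nni.__version__)   # '2.0' for the release wheel, '999.dev0' for a develop install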
......
@@ -196,6 +196,7 @@ def copy_nni_node(version):
         package_json['version'] = version
         json.dump(package_json, open('nni_node/package.json', 'w'), indent=2)
+    # reinstall without development dependencies
     _yarn('ts/nni_manager', '--prod', '--cwd', str(Path('nni_node').resolve()))
     shutil.copytree('ts/webui/build', 'nni_node/static')
@@ -226,9 +227,9 @@ def _symlink(target_file, link_location):
 def _print(*args):
     if sys.platform == 'win32':
-        print(*args)
+        print(*args, flush=True)
     else:
-        print('\033[1;36m#', *args, '\033[0m')
+        print('\033[1;36m#', *args, '\033[0m', flush=True)

 generated_files = [
......