Unverified Commit 4784cc6c authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

Merge pull request #3302 from microsoft/v2.0-merge

Merge branch v2.0 into master (no squash)
parents 25db55ca 349ead41
from .interface import BaseTrainer
from .interface import BaseTrainer, BaseOneShotTrainer
from .pytorch import PyTorchImageClassificationTrainer, PyTorchMultiModelTrainer
......@@ -43,7 +43,7 @@ def get_default_transform(dataset: str) -> Any:
return None
@register_trainer()
@register_trainer
class PyTorchImageClassificationTrainer(BaseTrainer):
"""
Image classification trainer for PyTorch.
......@@ -80,7 +80,7 @@ class PyTorchImageClassificationTrainer(BaseTrainer):
Keyword arguments passed to trainer. Will be passed to Trainer class in future. Currently,
only the key ``max_epochs`` is useful.
"""
super(PyTorchImageClassificationTrainer, self).__init__()
super().__init__()
self._use_cuda = torch.cuda.is_available()
self.model = model
if self._use_cuda:
......
......@@ -6,6 +6,7 @@ import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..interface import BaseOneShotTrainer
from .utils import AverageMeterGroup, replace_layer_choice, replace_input_choice
......@@ -17,13 +18,14 @@ _logger = logging.getLogger(__name__)
class DartsLayerChoice(nn.Module):
def __init__(self, layer_choice):
super(DartsLayerChoice, self).__init__()
self.name = layer_choice.key
self.op_choices = nn.ModuleDict(layer_choice.named_children())
self.alpha = nn.Parameter(torch.randn(len(self.op_choices)) * 1e-3)
def forward(self, *args, **kwargs):
op_results = torch.stack([op(*args, **kwargs) for op in self.op_choices.values()])
alpha_shape = [-1] + [1] * (len(op_results.size()) - 1)
return torch.sum(op_results * self.alpha.view(*alpha_shape), 0)
return torch.sum(op_results * F.softmax(self.alpha, -1).view(*alpha_shape), 0)
def parameters(self):
for _, p in self.named_parameters():
......@@ -42,13 +44,14 @@ class DartsLayerChoice(nn.Module):
class DartsInputChoice(nn.Module):
def __init__(self, input_choice):
super(DartsInputChoice, self).__init__()
self.name = input_choice.key
self.alpha = nn.Parameter(torch.randn(input_choice.n_candidates) * 1e-3)
self.n_chosen = input_choice.n_chosen or 1
def forward(self, inputs):
inputs = torch.stack(inputs)
alpha_shape = [-1] + [1] * (len(inputs.size()) - 1)
return torch.sum(inputs * self.alpha.view(*alpha_shape), 0)
return torch.sum(inputs * F.softmax(self.alpha, -1).view(*alpha_shape), 0)
def parameters(self):
for _, p in self.named_parameters():
......@@ -123,7 +126,15 @@ class DartsTrainer(BaseOneShotTrainer):
module.to(self.device)
self.model_optim = optimizer
self.ctrl_optim = torch.optim.Adam([m.alpha for _, m in self.nas_modules], arc_learning_rate, betas=(0.5, 0.999),
# use the same architecture weight for modules with duplicated names
ctrl_params = {}
for _, m in self.nas_modules:
if m.name in ctrl_params:
assert m.alpha.size() == ctrl_params[m.name].size(), 'Size of parameters with the same label should be same.'
m.alpha = ctrl_params[m.name]
else:
ctrl_params[m.name] = m.alpha
self.ctrl_optim = torch.optim.Adam(list(ctrl_params.values()), arc_learning_rate, betas=(0.5, 0.999),
weight_decay=1.0E-3)
self.unrolled = unrolled
self.grad_clip = 5.
......
......@@ -157,6 +157,7 @@ class ProxylessTrainer(BaseOneShotTrainer):
module.to(self.device)
self.optimizer = optimizer
# we do not support deduplicate control parameters with same label (like DARTS) yet.
self.ctrl_optim = torch.optim.Adam([m.alpha for _, m in self.nas_modules], arc_learning_rate,
weight_decay=0, betas=(0, 0.999), eps=1e-8)
self._init_dataloader()
......
......@@ -6,6 +6,7 @@ from collections import OrderedDict
import numpy as np
import torch
import nni.retiarii.nn.pytorch as nn
from nni.nas.pytorch.mutables import InputChoice, LayerChoice
_logger = logging.getLogger(__name__)
......@@ -157,7 +158,7 @@ def replace_layer_choice(root_module, init_fn, modules=None):
List[Tuple[str, nn.Module]]
A list from layer choice keys (names) and replaced modules.
"""
return _replace_module_with_type(root_module, init_fn, LayerChoice, modules)
return _replace_module_with_type(root_module, init_fn, (LayerChoice, nn.LayerChoice), modules)
def replace_input_choice(root_module, init_fn, modules=None):
......@@ -178,4 +179,4 @@ def replace_input_choice(root_module, init_fn, modules=None):
List[Tuple[str, nn.Module]]
A list from layer choice keys (names) and replaced modules.
"""
return _replace_module_with_type(root_module, init_fn, InputChoice, modules)
return _replace_module_with_type(root_module, init_fn, (InputChoice, nn.InputChoice), modules)
import inspect
import warnings
from collections import defaultdict
from typing import Any
......@@ -11,6 +12,13 @@ def import_(target: str, allow_none: bool = False) -> Any:
return getattr(module, identifier)
def version_larger_equal(a: str, b: str) -> bool:
    """Return True if version string ``a`` is at least version ``b``.

    Any local version label (the part after ``+``, e.g. ``1.8.1+cuda``)
    is stripped before comparison. Every remaining dot-separated segment
    must be an integer; comparison is element-wise on the integer tuples.
    """
    # TODO: refactor later
    def as_tuple(version: str) -> tuple:
        # Drop the local-version suffix ("+...") per PEP 440, then
        # convert "x.y.z" into a comparable tuple of ints.
        return tuple(int(seg) for seg in version.split('+')[0].split('.'))

    return as_tuple(a) >= as_tuple(b)
_records = {}
......@@ -19,6 +27,11 @@ def get_records():
return _records
def clear_records():
    """Discard all stored records by rebinding the global store to a fresh dict."""
    global _records
    _records = dict()
def add_record(key, value):
"""
"""
......@@ -28,69 +41,83 @@ def add_record(key, value):
_records[key] = value
def _register_module(original_class):
orig_init = original_class.__init__
argname_list = list(inspect.signature(original_class).parameters.keys())
# Make copy of original __init__, so we can call it without recursion
def del_record(key):
    """Remove the record stored under ``key``; silently do nothing if absent."""
    global _records
    if _records is None:
        # Nothing to delete when the store has been torn down.
        return
    _records.pop(key, None)
def __init__(self, *args, **kws):
full_args = {}
full_args.update(kws)
for i, arg in enumerate(args):
full_args[argname_list[i]] = arg
add_record(id(self), full_args)
orig_init(self, *args, **kws) # Call the original __init__
def _blackbox_cls(cls, module_name, register_format=None):
class wrapper(cls):
def __init__(self, *args, **kwargs):
argname_list = list(inspect.signature(cls).parameters.keys())
full_args = {}
full_args.update(kwargs)
original_class.__init__ = __init__ # Set the class' __init__ to the new one
return original_class
assert len(args) <= len(argname_list), f'Length of {args} is greater than length of {argname_list}.'
for argname, value in zip(argname_list, args):
full_args[argname] = value
# eject un-serializable arguments
for k in list(full_args.keys()):
# The list is not complete and does not support nested cases.
if not isinstance(full_args[k], (int, float, str, dict, list, tuple)):
if not (register_format == 'full' and k == 'model'):
# no warning if it is base model in trainer
warnings.warn(f'{cls} has un-serializable arguments {k} whose value is {full_args[k]}. \
This is not supported. You can ignore this warning if you are passing the model to trainer.')
full_args.pop(k)
def register_module():
"""
Register a module.
"""
# use it as a decorator: @register_module()
def _register(cls):
m = _register_module(
original_class=cls)
return m
if register_format == 'args':
add_record(id(self), full_args)
elif register_format == 'full':
full_class_name = cls.__module__ + '.' + cls.__name__
add_record(id(self), {'modulename': full_class_name, 'args': full_args})
super().__init__(*args, **kwargs)
return _register
def __del__(self):
del_record(id(self))
# using module_name instead of cls.__module__ because it's more natural to see where the module gets wrapped
# instead of simply putting torch.nn or etc.
wrapper.__module__ = module_name
wrapper.__name__ = cls.__name__
wrapper.__qualname__ = cls.__qualname__
wrapper.__init__.__doc__ = cls.__init__.__doc__
def _register_trainer(original_class):
orig_init = original_class.__init__
argname_list = list(inspect.signature(original_class).parameters.keys())
# Make copy of original __init__, so we can call it without recursion
return wrapper
full_class_name = original_class.__module__ + '.' + original_class.__name__
def __init__(self, *args, **kws):
full_args = {}
full_args.update(kws)
for i, arg in enumerate(args):
# TODO: support both pytorch and tensorflow
from .nn.pytorch import Module
if isinstance(args[i], Module):
# ignore the base model object
continue
full_args[argname_list[i]] = arg
add_record(id(self), {'modulename': full_class_name, 'args': full_args})
def blackbox(cls, *args, **kwargs):
"""
To create an blackbox instance inline without decorator. For example,
orig_init(self, *args, **kws) # Call the original __init__
.. code-block:: python
self.op = blackbox(MyCustomOp, hidden_units=128)
"""
# get caller module name
frm = inspect.stack()[1]
module_name = inspect.getmodule(frm[0]).__name__
return _blackbox_cls(cls, module_name, 'args')(*args, **kwargs)
original_class.__init__ = __init__ # Set the class' __init__ to the new one
return original_class
def blackbox_module(cls):
    """
    Register a module. Use it as a decorator.
    """
    # Capture the caller's module name so the wrapped class reports the
    # location where it was decorated, not this utility module.
    caller_frame = inspect.stack()[1]
    caller_module = inspect.getmodule(caller_frame[0]).__name__
    return _blackbox_cls(cls, caller_module, 'args')
def register_trainer():
def _register(cls):
m = _register_trainer(
original_class=cls)
return m
return _register
def register_trainer(cls):
    """
    Register a trainer. Use it as a decorator.
    """
    # Capture the caller's module name so the wrapped trainer class reports
    # the location where it was decorated, not this utility module.
    caller_frame = inspect.stack()[1]
    caller_module = inspect.getmodule(caller_frame[0]).__name__
    return _blackbox_cls(cls, caller_module, 'full')
_last_uid = defaultdict(int)
......
......@@ -12,6 +12,13 @@ import colorama
from .env_vars import dispatcher_env_vars, trial_env_vars
handlers = {}
log_format = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s'
time_format = '%Y-%m-%d %H:%M:%S'
formatter = Formatter(log_format, time_format)
def init_logger() -> None:
"""
This function will (and should only) get invoked on the first time of importing nni (no matter which submodule).
......@@ -37,6 +44,8 @@ def init_logger() -> None:
_init_logger_standalone()
logging.getLogger('filelock').setLevel(logging.WARNING)
def init_logger_experiment() -> None:
"""
......@@ -44,15 +53,19 @@ def init_logger_experiment() -> None:
This function will get invoked after `init_logger()`.
"""
formatter.format = _colorful_format
colorful_formatter = Formatter(log_format, time_format)
colorful_formatter.format = _colorful_format
handlers['_default_'].setFormatter(colorful_formatter)
def start_experiment_log(experiment_id: str, log_directory: Path, debug: bool) -> None:
    """Attach a per-experiment file handler writing to ``dispatcher.log``.

    The handler is registered under ``experiment_id`` so it can later be
    detached by ``stop_experiment_log``. DEBUG level is used when ``debug``
    is true, INFO otherwise.
    """
    level = logging.DEBUG if debug else logging.INFO
    path = _prepare_log_dir(log_directory) / 'dispatcher.log'
    _register_handler(FileHandler(path), level, experiment_id)
time_format = '%Y-%m-%d %H:%M:%S'
def stop_experiment_log(experiment_id: str) -> None:
    """Detach and forget the log handler registered for ``experiment_id``.

    No-op when no handler was registered under that id.
    """
    handler = handlers.pop(experiment_id, None)
    if handler is not None:
        logging.getLogger().removeHandler(handler)
formatter = Formatter(
'[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s',
time_format
)
def _init_logger_dispatcher() -> None:
log_level_map = {
......@@ -66,26 +79,20 @@ def _init_logger_dispatcher() -> None:
log_path = _prepare_log_dir(dispatcher_env_vars.NNI_LOG_DIRECTORY) / 'dispatcher.log'
log_level = log_level_map.get(dispatcher_env_vars.NNI_LOG_LEVEL, logging.INFO)
_setup_root_logger(FileHandler(log_path), log_level)
_register_handler(FileHandler(log_path), log_level)
def _init_logger_trial() -> None:
log_path = _prepare_log_dir(trial_env_vars.NNI_OUTPUT_DIR) / 'trial.log'
log_file = open(log_path, 'w')
_setup_root_logger(StreamHandler(log_file), logging.INFO)
_register_handler(StreamHandler(log_file), logging.INFO)
if trial_env_vars.NNI_PLATFORM == 'local':
sys.stdout = _LogFileWrapper(log_file)
def _init_logger_standalone() -> None:
_setup_nni_logger(StreamHandler(sys.stdout), logging.INFO)
# Following line does not affect NNI loggers, but without this user's logger won't
# print log even it's level is set to INFO, so we do it for user's convenience.
# If this causes any issue in future, remove it and use `logging.info()` instead of
# `logging.getLogger('xxx').info()` in all examples.
logging.basicConfig()
_register_handler(StreamHandler(sys.stdout), logging.INFO)
def _prepare_log_dir(path: Optional[str]) -> Path:
......@@ -95,20 +102,18 @@ def _prepare_log_dir(path: Optional[str]) -> Path:
ret.mkdir(parents=True, exist_ok=True)
return ret
def _setup_root_logger(handler: Handler, level: int) -> None:
_setup_logger('', handler, level)
def _setup_nni_logger(handler: Handler, level: int) -> None:
_setup_logger('nni', handler, level)
def _setup_logger(name: str, handler: Handler, level: int) -> None:
def _register_handler(handler: Handler, level: int, tag: str = '_default_') -> None:
assert tag not in handlers
handlers[tag] = handler
handler.setFormatter(formatter)
logger = logging.getLogger(name)
logger = logging.getLogger()
logger.addHandler(handler)
logger.setLevel(level)
logger.propagate = False
def _colorful_format(record):
time = formatter.formatTime(record, time_format)
if not record.name.startswith('nni.'):
return '[{}] ({}) {}'.format(time, record.name, record.msg % record.args)
if record.levelno >= logging.ERROR:
color = colorama.Fore.RED
elif record.levelno >= logging.WARNING:
......@@ -118,7 +123,6 @@ def _colorful_format(record):
else:
color = colorama.Fore.BLUE
msg = color + (record.msg % record.args) + colorama.Style.RESET_ALL
time = formatter.formatTime(record, time_format)
if record.levelno < logging.INFO:
return '[{}] {}:{} {}'.format(time, record.threadName, record.name, msg)
else:
......
......@@ -9,7 +9,7 @@ if trial_env_vars.NNI_PLATFORM is None:
from .standalone import *
elif trial_env_vars.NNI_PLATFORM == 'unittest':
from .test import *
elif trial_env_vars.NNI_PLATFORM in ('local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml', 'adl', 'heterogeneous'):
elif trial_env_vars.NNI_PLATFORM in ('local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml', 'adl', 'hybrid'):
from .local import *
else:
raise RuntimeError('Unknown platform %s' % trial_env_vars.NNI_PLATFORM)
......@@ -124,7 +124,7 @@ common_schema = {
Optional('maxExecDuration'): And(Regex(r'^[1-9][0-9]*[s|m|h|d]$', error='ERROR: maxExecDuration format is [digit]{s,m,h,d}')),
Optional('maxTrialNum'): setNumberRange('maxTrialNum', int, 1, 99999),
'trainingServicePlatform': setChoice(
'trainingServicePlatform', 'remote', 'local', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml', 'adl', 'heterogeneous'),
'trainingServicePlatform', 'remote', 'local', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml', 'adl', 'hybrid'),
Optional('searchSpacePath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'searchSpacePath'),
Optional('multiPhase'): setType('multiPhase', bool),
Optional('multiThread'): setType('multiThread', bool),
......@@ -262,7 +262,7 @@ aml_config_schema = {
}
}
heterogeneous_trial_schema = {
hybrid_trial_schema = {
'trial': {
'codeDir': setPathCheck('codeDir'),
Optional('nniManagerNFSMountPath'): setPathCheck('nniManagerNFSMountPath'),
......@@ -279,8 +279,8 @@ heterogeneous_trial_schema = {
}
}
heterogeneous_config_schema = {
'heterogeneousConfig': {
hybrid_config_schema = {
'hybridConfig': {
'trainingServicePlatforms': ['local', 'remote', 'pai', 'aml']
}
}
......@@ -461,7 +461,7 @@ training_service_schema_dict = {
'frameworkcontroller': Schema({**common_schema, **frameworkcontroller_trial_schema, **frameworkcontroller_config_schema}),
'aml': Schema({**common_schema, **aml_trial_schema, **aml_config_schema}),
'dlts': Schema({**common_schema, **dlts_trial_schema, **dlts_config_schema}),
'heterogeneous': Schema({**common_schema, **heterogeneous_trial_schema, **heterogeneous_config_schema, **machine_list_schema,
'hybrid': Schema({**common_schema, **hybrid_trial_schema, **hybrid_config_schema, **machine_list_schema,
**pai_config_schema, **aml_config_schema, **remote_config_schema}),
}
......@@ -479,7 +479,7 @@ class NNIConfigSchema:
self.validate_pai_trial_conifg(experiment_config)
self.validate_kubeflow_operators(experiment_config)
self.validate_eth0_device(experiment_config)
self.validate_heterogeneous_platforms(experiment_config)
self.validate_hybrid_platforms(experiment_config)
def validate_tuner_adivosr_assessor(self, experiment_config):
if experiment_config.get('advisor'):
......@@ -590,15 +590,15 @@ class NNIConfigSchema:
and 'eth0' not in netifaces.interfaces():
raise SchemaError('This machine does not contain eth0 network device, please set nniManagerIp in config file!')
def validate_heterogeneous_platforms(self, experiment_config):
def validate_hybrid_platforms(self, experiment_config):
required_config_name_map = {
'remote': 'machineList',
'aml': 'amlConfig',
'pai': 'paiConfig'
}
if experiment_config.get('trainingServicePlatform') == 'heterogeneous':
for platform in experiment_config['heterogeneousConfig']['trainingServicePlatforms']:
if experiment_config.get('trainingServicePlatform') == 'hybrid':
for platform in experiment_config['hybridConfig']['trainingServicePlatforms']:
config_name = required_config_name_map.get(platform)
if config_name and not experiment_config.get(config_name):
raise SchemaError('Need to set {0} for {1} in heterogeneous mode!'.format(config_name, platform))
raise SchemaError('Need to set {0} for {1} in hybrid mode!'.format(config_name, platform))
\ No newline at end of file
......@@ -17,7 +17,7 @@ from .launcher_utils import validate_all_content
from .rest_utils import rest_put, rest_post, check_rest_server, check_response
from .url_utils import cluster_metadata_url, experiment_url, get_local_urls
from .config_utils import Config, Experiments
from .common_utils import get_yml_content, get_json_content, print_error, print_normal, \
from .common_utils import get_yml_content, get_json_content, print_error, print_normal, print_warning, \
detect_port, get_user
from .constants import NNICTL_HOME_DIR, ERROR_INFO, REST_TIME_OUT, EXPERIMENT_SUCCESS_INFO, LOG_HEADER
......@@ -47,10 +47,10 @@ def start_rest_server(port, platform, mode, experiment_id, foreground=False, log
'You could use \'nnictl create --help\' to get help information' % port)
exit(1)
if (platform != 'local') and detect_port(int(port) + 1):
print_error('PAI mode need an additional adjacent port %d, and the port %d is used by another process!\n' \
if (platform not in ['local', 'aml']) and detect_port(int(port) + 1):
print_error('%s mode need an additional adjacent port %d, and the port %d is used by another process!\n' \
'You could set another port to start experiment!\n' \
'You could use \'nnictl create --help\' to get help information' % ((int(port) + 1), (int(port) + 1)))
'You could use \'nnictl create --help\' to get help information' % (platform, (int(port) + 1), (int(port) + 1)))
exit(1)
print_normal('Starting restful server...')
......@@ -300,23 +300,25 @@ def set_aml_config(experiment_config, port, config_file_name):
#set trial_config
return set_trial_config(experiment_config, port, config_file_name), err_message
def set_heterogeneous_config(experiment_config, port, config_file_name):
'''set heterogeneous configuration'''
heterogeneous_config_data = dict()
heterogeneous_config_data['heterogeneous_config'] = experiment_config['heterogeneousConfig']
platform_list = experiment_config['heterogeneousConfig']['trainingServicePlatforms']
def set_hybrid_config(experiment_config, port, config_file_name):
'''set hybrid configuration'''
hybrid_config_data = dict()
hybrid_config_data['hybrid_config'] = experiment_config['hybridConfig']
platform_list = experiment_config['hybridConfig']['trainingServicePlatforms']
for platform in platform_list:
if platform == 'aml':
heterogeneous_config_data['aml_config'] = experiment_config['amlConfig']
hybrid_config_data['aml_config'] = experiment_config['amlConfig']
elif platform == 'remote':
if experiment_config.get('remoteConfig'):
heterogeneous_config_data['remote_config'] = experiment_config['remoteConfig']
heterogeneous_config_data['machine_list'] = experiment_config['machineList']
hybrid_config_data['remote_config'] = experiment_config['remoteConfig']
hybrid_config_data['machine_list'] = experiment_config['machineList']
elif platform == 'local' and experiment_config.get('localConfig'):
heterogeneous_config_data['local_config'] = experiment_config['localConfig']
hybrid_config_data['local_config'] = experiment_config['localConfig']
elif platform == 'pai':
heterogeneous_config_data['pai_config'] = experiment_config['paiConfig']
response = rest_put(cluster_metadata_url(port), json.dumps(heterogeneous_config_data), REST_TIME_OUT)
hybrid_config_data['pai_config'] = experiment_config['paiConfig']
# It needs to connect all remote machines, set longer timeout here to wait for restful server connection response.
time_out = 60 if 'remote' in platform_list else REST_TIME_OUT
response = rest_put(cluster_metadata_url(port), json.dumps(hybrid_config_data), time_out)
err_message = None
if not response or not response.status_code == 200:
if response is not None:
......@@ -412,10 +414,10 @@ def set_experiment(experiment_config, mode, port, config_file_name):
{'key': 'aml_config', 'value': experiment_config['amlConfig']})
request_data['clusterMetaData'].append(
{'key': 'trial_config', 'value': experiment_config['trial']})
elif experiment_config['trainingServicePlatform'] == 'heterogeneous':
elif experiment_config['trainingServicePlatform'] == 'hybrid':
request_data['clusterMetaData'].append(
{'key': 'heterogeneous_config', 'value': experiment_config['heterogeneousConfig']})
platform_list = experiment_config['heterogeneousConfig']['trainingServicePlatforms']
{'key': 'hybrid_config', 'value': experiment_config['hybridConfig']})
platform_list = experiment_config['hybridConfig']['trainingServicePlatforms']
request_dict = {
'aml': {'key': 'aml_config', 'value': experiment_config.get('amlConfig')},
'remote': {'key': 'machine_list', 'value': experiment_config.get('machineList')},
......@@ -460,8 +462,8 @@ def set_platform_config(platform, experiment_config, port, config_file_name, res
config_result, err_msg = set_dlts_config(experiment_config, port, config_file_name)
elif platform == 'aml':
config_result, err_msg = set_aml_config(experiment_config, port, config_file_name)
elif platform == 'heterogeneous':
config_result, err_msg = set_heterogeneous_config(experiment_config, port, config_file_name)
elif platform == 'hybrid':
config_result, err_msg = set_hybrid_config(experiment_config, port, config_file_name)
else:
raise Exception(ERROR_INFO % 'Unsupported platform!')
exit(1)
......@@ -509,6 +511,11 @@ def launch_experiment(args, experiment_config, mode, experiment_id):
rest_process, start_time = start_rest_server(args.port, experiment_config['trainingServicePlatform'], \
mode, experiment_id, foreground, log_dir, log_level)
nni_config.set_config('restServerPid', rest_process.pid)
# save experiment information
nnictl_experiment_config = Experiments()
nnictl_experiment_config.add_experiment(experiment_id, args.port, start_time,
experiment_config['trainingServicePlatform'],
experiment_config['experimentName'], pid=rest_process.pid, logDir=log_dir)
# Deal with annotation
if experiment_config.get('useAnnotation'):
path = os.path.join(tempfile.gettempdir(), get_user(), 'nni', 'annotation')
......@@ -546,11 +553,6 @@ def launch_experiment(args, experiment_config, mode, experiment_id):
# start a new experiment
print_normal('Starting experiment...')
# save experiment information
nnictl_experiment_config = Experiments()
nnictl_experiment_config.add_experiment(experiment_id, args.port, start_time,
experiment_config['trainingServicePlatform'],
experiment_config['experimentName'], pid=rest_process.pid, logDir=log_dir)
# set debug configuration
if mode != 'view' and experiment_config.get('debug') is None:
experiment_config['debug'] = args.debug
......@@ -567,7 +569,7 @@ def launch_experiment(args, experiment_config, mode, experiment_id):
raise Exception(ERROR_INFO % 'Restful server stopped!')
exit(1)
if experiment_config.get('nniManagerIp'):
web_ui_url_list = ['{0}:{1}'.format(experiment_config['nniManagerIp'], str(args.port))]
web_ui_url_list = ['http://{0}:{1}'.format(experiment_config['nniManagerIp'], str(args.port))]
else:
web_ui_url_list = get_local_urls(args.port)
nni_config.set_config('webuiUrl', web_ui_url_list)
......@@ -592,24 +594,28 @@ def create_experiment(args):
print_error('Please set correct config path!')
exit(1)
experiment_config = get_yml_content(config_path)
try:
config = ExperimentConfig(**experiment_config)
experiment_config = convert.to_v1_yaml(config)
except Exception:
pass
try:
validate_all_content(experiment_config, config_path)
except Exception as e:
print_error(e)
exit(1)
except Exception:
print_warning('Validation with V1 schema failed. Trying to convert from V2 format...')
try:
config = ExperimentConfig(**experiment_config)
experiment_config = convert.to_v1_yaml(config)
except Exception as e:
print_error(f'Conversion from v2 format failed: {repr(e)}')
try:
validate_all_content(experiment_config, config_path)
except Exception as e:
print_error(f'Config in v1 format validation failed. {repr(e)}')
exit(1)
nni_config.set_config('experimentConfig', experiment_config)
nni_config.set_config('restServerPort', args.port)
try:
launch_experiment(args, experiment_config, 'new', experiment_id)
except Exception as exception:
nni_config = Config(experiment_id)
restServerPid = nni_config.get_config('restServerPid')
restServerPid = Experiments().get_all_experiments().get(experiment_id, {}).get('pid')
if restServerPid:
kill_command(restServerPid)
print_error(exception)
......@@ -641,8 +647,7 @@ def manage_stopped_experiment(args, mode):
try:
launch_experiment(args, experiment_config, mode, experiment_id)
except Exception as exception:
nni_config = Config(experiment_id)
restServerPid = nni_config.get_config('restServerPid')
restServerPid = Experiments().get_all_experiments().get(experiment_id, {}).get('pid')
if restServerPid:
kill_command(restServerPid)
print_error(exception)
......
......@@ -105,7 +105,9 @@ def set_default_values(experiment_config):
experiment_config['maxExecDuration'] = '999d'
if experiment_config.get('maxTrialNum') is None:
experiment_config['maxTrialNum'] = 99999
if experiment_config['trainingServicePlatform'] == 'remote':
if experiment_config['trainingServicePlatform'] == 'remote' or \
experiment_config['trainingServicePlatform'] == 'hybrid' and \
'remote' in experiment_config['hybridConfig']['trainingServicePlatforms']:
for index in range(len(experiment_config['machineList'])):
if experiment_config['machineList'][index].get('port') is None:
experiment_config['machineList'][index]['port'] = 22
......
......@@ -2,6 +2,7 @@
# Licensed under the MIT license.
import argparse
import logging
import os
import pkg_resources
from colorama import init
......@@ -32,6 +33,8 @@ def nni_info(*args):
print('please run "nnictl {positional argument} --help" to see nnictl guidance')
def parse_args():
logging.getLogger().setLevel(logging.ERROR)
'''Definite the arguments users need to follow and input'''
parser = argparse.ArgumentParser(prog='nnictl', description='use nnictl command to control nni experiments')
parser.add_argument('--version', '-v', action='store_true')
......@@ -243,12 +246,9 @@ def parse_args():
def show_messsage_for_nnictl_package(args):
print_error('nnictl package command is replaced by nnictl algo, please run nnictl algo -h to show the usage')
parser_package_subparsers = subparsers.add_parser('package', help='control nni tuner and assessor packages').add_subparsers()
parser_package_subparsers.add_parser('install', help='install packages').set_defaults(func=show_messsage_for_nnictl_package)
parser_package_subparsers.add_parser('uninstall', help='uninstall packages').set_defaults(func=show_messsage_for_nnictl_package)
parser_package_subparsers.add_parser('show', help='show the information of packages').set_defaults(
func=show_messsage_for_nnictl_package)
parser_package_subparsers.add_parser('list', help='list installed packages').set_defaults(func=show_messsage_for_nnictl_package)
parser_package_subparsers = subparsers.add_parser('package', help='this argument is replaced by algo', prefix_chars='\n')
parser_package_subparsers.add_argument('args', nargs=argparse.REMAINDER)
parser_package_subparsers.set_defaults(func=show_messsage_for_nnictl_package)
#parse tensorboard command
parser_tensorboard = subparsers.add_parser('tensorboard', help='manage tensorboard')
......
......@@ -50,11 +50,9 @@ def update_experiment():
for key in experiment_dict.keys():
if isinstance(experiment_dict[key], dict):
if experiment_dict[key].get('status') != 'STOPPED':
nni_config = Config(key)
rest_pid = nni_config.get_config('restServerPid')
rest_pid = experiment_dict[key].get('pid')
if not detect_process(rest_pid):
experiment_config.update_experiment(key, 'status', 'STOPPED')
experiment_config.update_experiment(key, 'port', None)
continue
def check_experiment_id(args, update=True):
......@@ -83,10 +81,10 @@ def check_experiment_id(args, update=True):
experiment_information += EXPERIMENT_DETAIL_FORMAT % (key,
experiment_dict[key].get('experimentName', 'N/A'),
experiment_dict[key]['status'],
experiment_dict[key]['port'],
experiment_dict[key].get('port', 'N/A'),
experiment_dict[key].get('platform'),
experiment_dict[key]['startTime'],
experiment_dict[key]['endTime'])
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['startTime'] / 1000)) if isinstance(experiment_dict[key]['startTime'], int) else experiment_dict[key]['startTime'],
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['endTime'] / 1000)) if isinstance(experiment_dict[key]['endTime'], int) else experiment_dict[key]['endTime'])
print(EXPERIMENT_INFORMATION_FORMAT % experiment_information)
exit(1)
elif not running_experiment_list:
......@@ -130,7 +128,7 @@ def parse_ids(args):
return running_experiment_list
if args.port is not None:
for key in running_experiment_list:
if experiment_dict[key]['port'] == args.port:
if experiment_dict[key].get('port') == args.port:
result_list.append(key)
if args.id and result_list and args.id != result_list[0]:
print_error('Experiment id and resful server port not match')
......@@ -143,10 +141,10 @@ def parse_ids(args):
experiment_information += EXPERIMENT_DETAIL_FORMAT % (key,
experiment_dict[key].get('experimentName', 'N/A'),
experiment_dict[key]['status'],
experiment_dict[key]['port'],
experiment_dict[key].get('port', 'N/A'),
experiment_dict[key].get('platform'),
experiment_dict[key]['startTime'],
experiment_dict[key]['endTime'])
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['startTime'] / 1000)) if isinstance(experiment_dict[key]['startTime'], int) else experiment_dict[key]['startTime'],
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['endTime'] / 1000)) if isinstance(experiment_dict[key]['endTime'], int) else experiment_dict[key]['endTime'])
print(EXPERIMENT_INFORMATION_FORMAT % experiment_information)
exit(1)
else:
......@@ -186,7 +184,7 @@ def get_experiment_port(args):
exit(1)
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
return experiment_dict[experiment_id]['port']
return experiment_dict[experiment_id].get('port')
def convert_time_stamp_to_date(content):
'''Convert time stamp to date time format'''
......@@ -202,8 +200,9 @@ def convert_time_stamp_to_date(content):
def check_rest(args):
'''check if restful server is running'''
nni_config = Config(get_config_filename(args))
rest_port = nni_config.get_config('restServerPort')
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
rest_port = experiment_dict.get(get_config_filename(args)).get('port')
running, _ = check_rest_server_quick(rest_port)
if running:
print_normal('Restful server is running...')
......@@ -220,18 +219,19 @@ def stop_experiment(args):
if experiment_id_list:
for experiment_id in experiment_id_list:
print_normal('Stopping experiment %s' % experiment_id)
nni_config = Config(experiment_id)
rest_pid = nni_config.get_config('restServerPid')
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
rest_pid = experiment_dict.get(experiment_id).get('pid')
if rest_pid:
kill_command(rest_pid)
tensorboard_pid_list = nni_config.get_config('tensorboardPidList')
tensorboard_pid_list = experiment_dict.get(experiment_id).get('tensorboardPidList')
if tensorboard_pid_list:
for tensorboard_pid in tensorboard_pid_list:
try:
kill_command(tensorboard_pid)
except Exception as exception:
print_error(exception)
nni_config.set_config('tensorboardPidList', [])
experiment_config.update_experiment(experiment_id, 'tensorboardPidList', [])
print_normal('Stop experiment success.')
def trial_ls(args):
......@@ -250,9 +250,10 @@ def trial_ls(args):
if args.head and args.tail:
print_error('Head and tail cannot be set at the same time.')
return
nni_config = Config(get_config_filename(args))
rest_port = nni_config.get_config('restServerPort')
rest_pid = nni_config.get_config('restServerPid')
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
rest_port = experiment_dict.get(get_config_filename(args)).get('port')
rest_pid = experiment_dict.get(get_config_filename(args)).get('pid')
if not detect_process(rest_pid):
print_error('Experiment is not running...')
return
......@@ -281,9 +282,10 @@ def trial_ls(args):
def trial_kill(args):
'''List trial'''
nni_config = Config(get_config_filename(args))
rest_port = nni_config.get_config('restServerPort')
rest_pid = nni_config.get_config('restServerPid')
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
rest_port = experiment_dict.get(get_config_filename(args)).get('port')
rest_pid = experiment_dict.get(get_config_filename(args)).get('pid')
if not detect_process(rest_pid):
print_error('Experiment is not running...')
return
......@@ -312,9 +314,10 @@ def trial_codegen(args):
def list_experiment(args):
'''Get experiment information'''
nni_config = Config(get_config_filename(args))
rest_port = nni_config.get_config('restServerPort')
rest_pid = nni_config.get_config('restServerPid')
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
rest_port = experiment_dict.get(get_config_filename(args)).get('port')
rest_pid = experiment_dict.get(get_config_filename(args)).get('pid')
if not detect_process(rest_pid):
print_error('Experiment is not running...')
return
......@@ -333,8 +336,9 @@ def list_experiment(args):
def experiment_status(args):
'''Show the status of experiment'''
nni_config = Config(get_config_filename(args))
rest_port = nni_config.get_config('restServerPort')
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
rest_port = experiment_dict.get(get_config_filename(args)).get('port')
result, response = check_rest_server_quick(rest_port)
if not result:
print_normal('Restful server is not running...')
......@@ -620,12 +624,12 @@ def platform_clean(args):
break
if platform == 'remote':
machine_list = config_content.get('machineList')
remote_clean(machine_list, None)
remote_clean(machine_list)
elif platform == 'pai':
host = config_content.get('paiConfig').get('host')
user_name = config_content.get('paiConfig').get('userName')
output_dir = config_content.get('trial').get('outputDir')
hdfs_clean(host, user_name, output_dir, None)
hdfs_clean(host, user_name, output_dir)
print_normal('Done.')
def experiment_list(args):
......@@ -651,7 +655,7 @@ def experiment_list(args):
experiment_information += EXPERIMENT_DETAIL_FORMAT % (key,
experiment_dict[key].get('experimentName', 'N/A'),
experiment_dict[key]['status'],
experiment_dict[key]['port'],
experiment_dict[key].get('port', 'N/A'),
experiment_dict[key].get('platform'),
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['startTime'] / 1000)) if isinstance(experiment_dict[key]['startTime'], int) else experiment_dict[key]['startTime'],
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['endTime'] / 1000)) if isinstance(experiment_dict[key]['endTime'], int) else experiment_dict[key]['endTime'])
......@@ -752,9 +756,10 @@ def export_trials_data(args):
groupby.setdefault(content['trialJobId'], []).append(json.loads(content['data']))
return groupby
nni_config = Config(get_config_filename(args))
rest_port = nni_config.get_config('restServerPort')
rest_pid = nni_config.get_config('restServerPid')
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
rest_port = experiment_dict.get(get_config_filename(args)).get('port')
rest_pid = experiment_dict.get(get_config_filename(args)).get('pid')
if not detect_process(rest_pid):
print_error('Experiment is not running...')
......
......@@ -70,7 +70,7 @@ def format_tensorboard_log_path(path_list):
new_path_list.append('name%d:%s' % (index + 1, value))
return ','.join(new_path_list)
def start_tensorboard_process(args, nni_config, path_list, temp_nni_path):
def start_tensorboard_process(args, experiment_id, path_list, temp_nni_path):
'''call cmds to start tensorboard process in local machine'''
if detect_port(args.port):
print_error('Port %s is used by another process, please reset port!' % str(args.port))
......@@ -83,20 +83,19 @@ def start_tensorboard_process(args, nni_config, path_list, temp_nni_path):
url_list = get_local_urls(args.port)
print_green('Start tensorboard success!')
print_normal('Tensorboard urls: ' + ' '.join(url_list))
tensorboard_process_pid_list = nni_config.get_config('tensorboardPidList')
experiment_config = Experiments()
tensorboard_process_pid_list = experiment_config.get_all_experiments().get(experiment_id).get('tensorboardPidList')
if tensorboard_process_pid_list is None:
tensorboard_process_pid_list = [tensorboard_process.pid]
else:
tensorboard_process_pid_list.append(tensorboard_process.pid)
nni_config.set_config('tensorboardPidList', tensorboard_process_pid_list)
experiment_config.update_experiment(experiment_id, 'tensorboardPidList', tensorboard_process_pid_list)
def stop_tensorboard(args):
'''stop tensorboard'''
experiment_id = check_experiment_id(args)
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
nni_config = Config(experiment_id)
tensorboard_pid_list = nni_config.get_config('tensorboardPidList')
tensorboard_pid_list = experiment_config.get_all_experiments().get(experiment_id).get('tensorboardPidList')
if tensorboard_pid_list:
for tensorboard_pid in tensorboard_pid_list:
try:
......@@ -104,7 +103,7 @@ def stop_tensorboard(args):
call(cmds)
except Exception as exception:
print_error(exception)
nni_config.set_config('tensorboardPidList', [])
experiment_config.update_experiment(experiment_id, 'tensorboardPidList', [])
print_normal('Stop tensorboard success!')
else:
print_error('No tensorboard configuration!')
......@@ -164,4 +163,4 @@ def start_tensorboard(args):
os.makedirs(temp_nni_path, exist_ok=True)
path_list = get_path_list(args, nni_config, trial_content, temp_nni_path)
start_tensorboard_process(args, nni_config, path_list, temp_nni_path)
\ No newline at end of file
start_tensorboard_process(args, experiment_id, path_list, temp_nni_path)
......@@ -5,7 +5,7 @@ import json
import os
from .rest_utils import rest_put, rest_post, rest_get, check_rest_server_quick, check_response
from .url_utils import experiment_url, import_data_url
from .config_utils import Config
from .config_utils import Config, Experiments
from .common_utils import get_json_content, print_normal, print_error, print_warning
from .nnictl_utils import get_experiment_port, get_config_filename, detect_process
from .launcher_utils import parse_time
......@@ -58,8 +58,9 @@ def get_query_type(key):
def update_experiment_profile(args, key, value):
'''call restful server to update experiment profile'''
nni_config = Config(get_config_filename(args))
rest_port = nni_config.get_config('restServerPort')
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
rest_port = experiment_dict.get(get_config_filename(args)).get('port')
running, _ = check_rest_server_quick(rest_port)
if running:
response = rest_get(experiment_url(rest_port), REST_TIME_OUT)
......
......@@ -4,27 +4,26 @@
jobs:
- job: ubuntu_latest
pool:
# FIXME: In ubuntu-20.04 Python interpreter crashed during SMAC UT
vmImage: ubuntu-18.04
vmImage: ubuntu-latest
# This platform tests lint and doc first.
steps:
- task: UsePythonVersion@0
inputs:
versionSpec: 3.6
versionSpec: 3.8
displayName: Configure Python version
- script: |
set -e
python3 -m pip install --upgrade pip setuptools
python3 -m pip install pytest coverage
python3 -m pip install pylint flake8
python -m pip install --upgrade pip setuptools
python -m pip install pytest coverage
python -m pip install pylint flake8
echo "##vso[task.setvariable variable=PATH]${HOME}/.local/bin:${PATH}"
displayName: Install Python tools
- script: |
python3 setup.py develop
python setup.py develop
displayName: Install NNI
- script: |
......@@ -35,24 +34,28 @@ jobs:
yarn eslint
displayName: ESLint
# FIXME: temporarily fixed to pytorch 1.6 as 1.7 won't work with compression
- script: |
set -e
sudo apt-get install -y pandoc
python3 -m pip install --upgrade pygments
python3 -m pip install --upgrade torch>=1.7.0+cpu torchvision>=0.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
python3 -m pip install --upgrade tensorflow
python3 -m pip install --upgrade gym onnx peewee thop graphviz
python3 -m pip install sphinx==1.8.3 sphinx-argparse==0.2.5 sphinx-markdown-tables==0.0.9 sphinx-rtd-theme==0.4.2 sphinxcontrib-websupport==1.1.0 recommonmark==0.5.0 nbsphinx
sudo apt-get install swig -y
python3 -m pip install -e .[SMAC,BOHB]
python -m pip install --upgrade pygments
python -m pip install "torch==1.6.0+cpu" "torchvision==0.7.0+cpu" -f https://download.pytorch.org/whl/torch_stable.html
python -m pip install tensorflow
python -m pip install gym onnx peewee thop graphviz
python -m pip install sphinx==3.3.1 sphinx-argparse==0.2.5 sphinx-rtd-theme==0.4.2 sphinxcontrib-websupport==1.1.0 nbsphinx
sudo apt-get remove swig -y
sudo apt-get install swig3.0 -y
sudo ln -s /usr/bin/swig3.0 /usr/bin/swig
python -m pip install -e .[SMAC,BOHB]
displayName: Install extra dependencies
- script: |
set -e
python3 -m pylint --rcfile pylintrc nni
python3 -m flake8 nni --count --select=E9,F63,F72,F82 --show-source --statistics
python -m pylint --rcfile pylintrc nni
python -m flake8 nni --count --select=E9,F63,F72,F82 --show-source --statistics
EXCLUDES=examples/trials/mnist-nas/*/mnist*.py,examples/trials/nas_cifar10/src/cifar10/general_child.py
python3 -m flake8 examples --count --exclude=$EXCLUDES --select=E9,F63,F72,F82 --show-source --statistics
python -m flake8 examples --count --exclude=$EXCLUDES --select=E9,F63,F72,F82 --show-source --statistics
displayName: pylint and flake8
- script: |
......@@ -61,10 +64,14 @@ jobs:
displayName: Check Sphinx documentation
- script: |
set -e
cd test
python3 -m pytest ut --ignore=ut/sdk/test_pruners.py --ignore=ut/sdk/test_compressor_tf.py
python3 -m pytest ut/sdk/test_pruners.py
python3 -m pytest ut/sdk/test_compressor_tf.py
python -m pytest ut --ignore=ut/sdk/test_pruners.py \
--ignore=ut/sdk/test_compressor_tf.py \
--ignore=ut/sdk/test_compressor_torch.py
python -m pytest ut/sdk/test_pruners.py
python -m pytest ut/sdk/test_compressor_tf.py
python -m pytest ut/sdk/test_compressor_torch.py
displayName: Python unit test
- script: |
......@@ -77,7 +84,7 @@ jobs:
- script: |
cd test
python3 nni_test/nnitest/run_tests.py --config config/pr_tests.yml
python nni_test/nnitest/run_tests.py --config config/pr_tests.yml
displayName: Simple integration test
......
trigger: none
pr: none
schedules:
- cron: 0 16 * * *
branches:
include: [ master ]
jobs:
- job: adl
pool: NNI CI KUBE CLI
timeoutInMinutes: 120
steps:
- script: |
export NNI_RELEASE=999.$(date -u +%Y%m%d%H%M%S)
echo "##vso[task.setvariable variable=PATH]${PATH}:${HOME}/.local/bin"
echo "##vso[task.setvariable variable=NNI_RELEASE]${NNI_RELEASE}"
echo "Working directory: ${PWD}"
echo "NNI version: ${NNI_RELEASE}"
echo "Build docker image: $(build_docker_image)"
python3 -m pip install --upgrade pip setuptools
displayName: Prepare
- script: |
set -e
python3 setup.py build_ts
python3 setup.py bdist_wheel -p manylinux1_x86_64
python3 -m pip install dist/nni-${NNI_RELEASE}-py3-none-manylinux1_x86_64.whl[SMAC,BOHB]
displayName: Build and install NNI
- script: |
set -e
cd examples/tuners/customized_tuner
python3 setup.py develop --user
nnictl algo register --meta meta_file.yml
displayName: Install customized tuner
- script: |
set -e
docker login -u nnidev -p $(docker_hub_password)
sed -i '$a RUN python3 -m pip install adaptdl tensorboard' Dockerfile
sed -i '$a COPY examples /examples' Dockerfile
sed -i '$a COPY test /test' Dockerfile
echo '## Build docker image ##'
docker build --build-arg NNI_RELEASE=${NNI_RELEASE} -t nnidev/nni-nightly .
echo '## Upload docker image ##'
docker push nnidev/nni-nightly
condition: eq(variables['build_docker_image'], 'true')
displayName: Build and upload docker image
- script: |
set -e
cd test
python3 nni_test/nnitest/generate_ts_config.py \
--ts adl \
--nni_docker_image nnidev/nni-nightly \
--checkpoint_storage_class $(checkpoint_storage_class) \
--checkpoint_storage_size $(checkpoint_storage_size) \
--nni_manager_ip $(nni_manager_ip)
python3 nni_test/nnitest/run_tests.py --config config/integration_tests.yml --ts adl
displayName: Integration test
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
trigger: none
pr: none
jobs:
- job: validate_version_number
pool:
......@@ -13,9 +16,11 @@ jobs:
displayName: Configure Python version
- script: |
echo $(build_type)
echo $(NNI_RELEASE)
export BRANCH_TAG=`git describe --tags --abbrev=0`
echo $BRANCH_TAG
if [[ $BRANCH_TAG = v$(NNI_RELEASE) && $(NNI_RELEASE) =~ ^v[0-9](.[0-9])+$ ]]; then
if [[ $BRANCH_TAG == v$(NNI_RELEASE) && $(NNI_RELEASE) =~ ^[0-9](.[0-9])+$ ]]; then
echo 'Build version match branch tag'
else
echo 'Build version does not match branch tag'
......@@ -25,15 +30,16 @@ jobs:
displayName: Validate release version number and branch tag
- script: |
echo $(build_type)
echo $(NNI_RELEASE)
if [[ $(NNI_RELEASE) =~ ^[0-9](.[0-9])+a[0-9]$ ]]; then
if [[ $(NNI_RELEASE) =~ ^[0-9](.[0-9])+(a|b|rc)[0-9]$ ]]; then
echo 'Valid prerelease version $(NNI_RELEASE)'
echo `git describe --tags --abbrev=0`
else
echo 'Invalid build version $(NNI_RELEASE)'
exit 1
fi
condition: ne( variables['build_type'], 'rerelease' )
condition: ne( variables['build_type'], 'release' )
displayName: Validate prerelease version number
- job: linux
......@@ -49,22 +55,22 @@ jobs:
displayName: Configure Python version
- script: |
python -m pip install --upgrade pip setuptools twine
python -m pip install --upgrade pip setuptools wheel twine
python test/vso_tools/build_wheel.py $(NNI_RELEASE)
displayName: Build wheel
if [ $(build_type) = 'release' ]
echo 'uploading release package to pypi...'
- script: |
if [[ $(build_type) == 'release' || $(build_type) == 'rc' ]]; then
echo 'uploading to pypi...'
python -m twine upload -u nni -p $(pypi_password) dist/*
then
else
echo 'uploading prerelease package to testpypi...'
echo 'uploading to testpypi...'
python -m twine upload -u nni -p $(pypi_password) --repository-url https://test.pypi.org/legacy/ dist/*
fi
displayName: Build and upload wheel
displayName: Upload wheel
- script: |
if [ $(build_type) = 'release' ]
then
if [[ $(build_type) == 'release' || $(build_type) == 'rc' ]]; then
docker login -u msranni -p $(docker_hub_password)
export IMAGE_NAME=msranni/nni
else
......@@ -74,9 +80,11 @@ jobs:
echo "## Building ${IMAGE_NAME}:$(NNI_RELEASE) ##"
docker build --build-arg NNI_RELEASE=$(NNI_RELEASE) -t ${IMAGE_NAME} .
docker tag ${IMAGE_NAME} ${IMAGE_NAME}:$(NNI_RELEASE)
docker push ${IMAGE_NAME}
docker push ${IMAGE_NAME}:$(NNI_RELEASE)
docker tag ${IMAGE_NAME} ${IMAGE_NAME}:v$(NNI_RELEASE)
docker push ${IMAGE_NAME}:v$(NNI_RELEASE)
if [[ $(build_type) != 'rc' ]]; then
docker push ${IMAGE_NAME}
fi
displayName: Build and upload docker image
- job: macos
......@@ -92,18 +100,19 @@ jobs:
displayName: Configure Python version
- script: |
python -m pip install --upgrade pip setuptools twine
python -m pip install --upgrade pip setuptools wheel twine
python test/vso_tools/build_wheel.py $(NNI_RELEASE)
displayName: Build wheel
if [ $(build_type) = 'release' ]
- script: |
if [[ $(build_type) == 'release' || $(build_type) == 'rc' ]]; then
echo '## uploading to pypi ##'
python -m twine upload -u nni -p $(pypi_password) dist/*
then
else
echo '## uploading to testpypi ##'
python -m twine upload -u nni -p $(pypi_password) --repository-url https://test.pypi.org/legacy/ dist/*
fi
displayName: Build and upload wheel
displayName: Upload wheel
- job: windows
dependsOn: validate_version_number
......@@ -118,15 +127,16 @@ jobs:
displayName: Configure Python version
- powershell: |
python -m pip install --upgrade pip setuptools twine
python -m pip install --upgrade pip setuptools wheel twine
python test/vso_tools/build_wheel.py $(NNI_RELEASE)
displayName: Build wheel
if($env:BUILD_TYPE -eq 'release'){
- powershell: |
if ($env:BUILD_TYPE -eq 'release' -Or $env:BUILD_TYPE -eq 'rc') {
Write-Host '## uploading to pypi ##'
python -m twine upload -u nni -p $(pypi_password) dist/*
}
else{
} else {
Write-Host '## uploading to testpypi ##'
python -m twine upload -u nni -p $(pypi_password) --repository-url https://test.pypi.org/legacy/ dist/*
}
displayName: Build and upload wheel
displayName: Upload wheel
......@@ -69,7 +69,6 @@ dependencies = [
'PythonWebHDFS',
'colorama',
'scikit-learn>=0.23.2',
'pkginfo',
'websockets',
'filelock',
'prettytable',
......@@ -112,11 +111,8 @@ def _setup():
python_requires = '>=3.6',
install_requires = dependencies,
extras_require = {
'SMAC': [
'ConfigSpaceNNI @ git+https://github.com/QuanluZhang/ConfigSpace.git',
'smac @ git+https://github.com/QuanluZhang/SMAC3.git'
],
'BOHB': ['ConfigSpace==0.4.7', 'statsmodels==0.10.0'],
'SMAC': ['ConfigSpaceNNI', 'smac4nni'],
'BOHB': ['ConfigSpace==0.4.7', 'statsmodels==0.12.0'],
'PPOTuner': ['enum34', 'gym']
},
setup_requires = ['requests'],
......@@ -189,6 +185,7 @@ class Build(build):
sys.exit('Please set environment variable "NNI_RELEASE=<release_version>"')
if os.path.islink('nni_node/main.js'):
sys.exit('A development build already exists. Please uninstall NNI and run "python3 setup.py clean --all".')
open('nni/version.py', 'w').write(f"__version__ = '{release}'")
super().run()
class Develop(develop):
......@@ -212,6 +209,7 @@ class Develop(develop):
super().finalize_options()
def run(self):
open('nni/version.py', 'w').write("__version__ = '999.dev0'")
if not self.skip_ts:
setup_ts.build(release=None)
super().run()
......
......@@ -196,6 +196,7 @@ def copy_nni_node(version):
package_json['version'] = version
json.dump(package_json, open('nni_node/package.json', 'w'), indent=2)
# reinstall without development dependencies
_yarn('ts/nni_manager', '--prod', '--cwd', str(Path('nni_node').resolve()))
shutil.copytree('ts/webui/build', 'nni_node/static')
......@@ -226,9 +227,9 @@ def _symlink(target_file, link_location):
def _print(*args):
if sys.platform == 'win32':
print(*args)
print(*args, flush=True)
else:
print('\033[1;36m#', *args, '\033[0m')
print('\033[1;36m#', *args, '\033[0m', flush=True)
generated_files = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment