"src/vscode:/vscode.git/clone" did not exist on "16fd8a21f250503115197812258490db75148e60"
Unverified Commit a5764016 authored by chicm-ms's avatar chicm-ms Committed by GitHub
Browse files

Install builtin tuners (#2439)

parent 0f7f9460
......@@ -12,6 +12,9 @@ import sys
import json_tricks
import numpy as np
from schema import Schema, Optional
from nni import ClassArgsValidator
from nni.common import multi_phase_enabled
from nni.msg_dispatcher_base import MsgDispatcherBase
from nni.protocol import CommandType, send
......@@ -249,6 +252,13 @@ class Bracket():
self.num_configs_to_run.append(len(hyper_configs))
self.increase_i()
class HyperbandClassArgsValidator(ClassArgsValidator):
    """Validate the `classArgs` section for the Hyperband advisor."""

    def validate_class_args(self, **kwargs):
        """Raise a schema error when the given classArgs are not valid for Hyperband."""
        rules = {
            'optimize_mode': self.choices('optimize_mode', 'maximize', 'minimize'),
            Optional('R'): int,
            Optional('eta'): int,
        }
        Schema(rules).validate(kwargs)
class Hyperband(MsgDispatcherBase):
"""Hyperband inherit from MsgDispatcherBase rather than Tuner, because it integrates both tuner's functions and assessor's functions.
......
......@@ -10,6 +10,8 @@ import logging
import hyperopt as hp
import numpy as np
from schema import Optional, Schema
from nni import ClassArgsValidator
from nni.tuner import Tuner
from nni.utils import NodeType, OptimizeMode, extract_scalar_reward, split_index
......@@ -178,6 +180,13 @@ def _add_index(in_x, parameter):
return parameter
return None # note: this is not written by original author, feel free to modify if you think it's incorrect
class HyperoptClassArgsValidator(ClassArgsValidator):
    """Validate the `classArgs` section for the Hyperopt-based tuners."""

    def validate_class_args(self, **kwargs):
        """Raise a schema error when the given classArgs are not valid for Hyperopt."""
        rules = {
            Optional('optimize_mode'): self.choices('optimize_mode', 'maximize', 'minimize'),
            Optional('parallel_optimize'): bool,
            Optional('constant_liar_type'): self.choices('constant_liar_type', 'min', 'max', 'mean'),
        }
        Schema(rules).validate(kwargs)
class HyperoptTuner(Tuner):
"""
......
......@@ -2,11 +2,21 @@
# Licensed under the MIT license.
import logging
from schema import Schema, Optional
from nni import ClassArgsValidator
from nni.assessor import Assessor, AssessResult
from nni.utils import extract_scalar_history
logger = logging.getLogger('medianstop_Assessor')
class MedianstopClassArgsValidator(ClassArgsValidator):
    """Validate the `classArgs` section for the Medianstop assessor."""

    def validate_class_args(self, **kwargs):
        """Raise a schema error when the given classArgs are not valid for Medianstop."""
        rules = {
            Optional('optimize_mode'): self.choices('optimize_mode', 'maximize', 'minimize'),
            Optional('start_step'): self.range('start_step', int, 0, 9999),
        }
        Schema(rules).validate(kwargs)
class MedianstopAssessor(Assessor):
"""MedianstopAssessor is The median stopping rule stops a pending trial X at step S
if the trial’s best objective value by step S is strictly worse than the median value
......
......@@ -12,7 +12,9 @@ import statistics
import warnings
from multiprocessing.dummy import Pool as ThreadPool
import numpy as np
from schema import Schema, Optional
from nni import ClassArgsValidator
import nni.metis_tuner.lib_constraint_summation as lib_constraint_summation
import nni.metis_tuner.lib_data as lib_data
import nni.metis_tuner.Regression_GMM.CreateModel as gmm_create_model
......@@ -31,6 +33,15 @@ CONSTRAINT_LOWERBOUND = None
CONSTRAINT_UPPERBOUND = None
CONSTRAINT_PARAMS_IDX = []
class MetisClassArgsValidator(ClassArgsValidator):
    """Validate the `classArgs` section for the Metis tuner."""

    def validate_class_args(self, **kwargs):
        """Raise a schema error when the given classArgs are not valid for Metis."""
        rules = {
            Optional('optimize_mode'): self.choices('optimize_mode', 'maximize', 'minimize'),
            Optional('no_resampling'): bool,
            Optional('no_candidates'): bool,
            Optional('selection_num_starting_points'): int,
            Optional('cold_start_num'): int,
        }
        Schema(rules).validate(kwargs)
class MetisTuner(Tuner):
"""
......
......@@ -7,17 +7,26 @@ networkmorphsim_tuner.py
import logging
import os
from schema import Optional, Schema
from nni.tuner import Tuner
from nni.utils import OptimizeMode, extract_scalar_reward
from nni.networkmorphism_tuner.bayesian import BayesianOptimizer
from nni.networkmorphism_tuner.nn import CnnGenerator, MlpGenerator
from nni.networkmorphism_tuner.utils import Constant
from nni.networkmorphism_tuner.graph import graph_to_json, json_to_graph
from nni import ClassArgsValidator
logger = logging.getLogger("NetworkMorphism_AutoML")
class NetworkMorphismClassArgsValidator(ClassArgsValidator):
    """Validate the `classArgs` section for the NetworkMorphism tuner."""

    def validate_class_args(self, **kwargs):
        """Raise a schema error when the given classArgs are not valid for NetworkMorphism."""
        rules = {
            Optional('optimize_mode'): self.choices('optimize_mode', 'maximize', 'minimize'),
            Optional('task'): self.choices('task', 'cv', 'nlp', 'common'),
            Optional('input_width'): int,
            Optional('input_channel'): int,
            Optional('n_output_node'): int,
        }
        Schema(rules).validate(kwargs)
class NetworkMorphismTuner(Tuner):
"""
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import importlib
import importlib.util
import os
import site
import sys
from collections import defaultdict
from pathlib import Path

import ruamel.yaml as yaml

from .constants import BuiltinAlgorithms
# Algorithm categories recognized by every lookup helper in this module.
ALGO_TYPES = ['tuners', 'assessors', 'advisors']
def get_all_builtin_names(algo_type):
    """Get all valid builtin names, including:
    1. BuiltinAlgorithms which is pre-installed.
    2. User installed packages in <nni_installation_path>/config/installed_packages.yml

    Parameters
    ----------
    algo_type: str
        can be one of 'tuners', 'assessors' or 'advisors'

    Returns: list of string
    -------
    All builtin names of the specified type, e.g. all builtin tuner names
    when algo_type is 'tuners'.
    """
    assert algo_type in ALGO_TYPES
    # The merged dict already contains both pre-installed and user-installed metas.
    return [meta['name'] for meta in _get_merged_builtin_dict()[algo_type]]
def get_not_installable_builtin_names(algo_type=None):
    """Get builtin names in BuiltinAlgorithms which do not need to be installed
    and can be used once NNI is installed.

    Parameters
    ----------
    algo_type: str | None
        can be one of 'tuners', 'assessors', 'advisors' or None

    Returns: list of string
    -------
    All builtin names of the specified type; if algo_type is None, builtin
    names of every type.
    """
    if algo_type is None:
        meta = BuiltinAlgorithms
    else:
        assert algo_type in ALGO_TYPES
        meta = {algo_type: BuiltinAlgorithms[algo_type]}
    # Iterate in ALGO_TYPES order so output ordering is deterministic.
    return [entry['name'] for t in ALGO_TYPES if t in meta for entry in meta[t]]
def get_builtin_algo_meta(algo_type=None, builtin_name=None):
    """Get meta information of builtin algorithms from:
    1. Pre-installed BuiltinAlgorithms
    2. User installed packages in <nni_installation_path>/config/installed_packages.yml

    Parameters
    ----------
    algo_type: str | None
        can be one of 'tuners', 'assessors', 'advisors' or None
    builtin_name: str | None
        builtin name.

    Returns: dict | list of dict | None
    -------
    If builtin_name is given, the meta dict of that algorithm (e.g. keys
    'name', 'class_name', 'class_args', 'class_args_validator'), or None
    when it is not found. If builtin_name is None, a list of meta dicts
    (or the whole merged dict when algo_type is also None).
    """
    merged = _get_merged_builtin_dict()

    # No filters at all: hand back the full merged mapping.
    if algo_type is None and builtin_name is None:
        return merged

    if algo_type:
        assert algo_type in ALGO_TYPES
        candidates = merged[algo_type]
    else:
        candidates = merged['tuners'] + merged['assessors'] + merged['advisors']

    if not builtin_name:
        return candidates

    for meta in candidates:
        if meta['name'] == builtin_name:
            return meta
    return None
def get_installed_package_meta(algo_type, builtin_name):
    """Get meta information of a user-installed algorithm from
    <nni_installation_path>/config/installed_packages.yml

    Parameters
    ----------
    algo_type: str | None
        can be one of 'tuners', 'assessors', 'advisors' or None
    builtin_name: str
        builtin name.

    Returns: dict | None
    -------
    Meta information of the specified algorithm, for example:
    {
        'class_args_validator': 'nni.smac_tuner.smac_tuner.SMACClassArgsValidator',
        'class_name': 'nni.smac_tuner.smac_tuner.SMACTuner',
        'name': 'SMAC'
    }
    """
    assert builtin_name is not None
    if algo_type:
        assert algo_type in ALGO_TYPES

    config = read_installed_package_meta()
    # When algo_type is not given, search every category in ALGO_TYPES order.
    searched_types = [algo_type] if algo_type else ALGO_TYPES
    for each_type in searched_types:
        for meta in config[each_type]:
            if meta['name'] == builtin_name:
                return meta
    return None
def _parse_full_class_name(full_class_name):
if not full_class_name:
return None, None
parts = full_class_name.split('.')
module_name, class_name = '.'.join(parts[:-1]), parts[-1]
return module_name, class_name
def get_builtin_module_class_name(algo_type, builtin_name):
    """Get module name and class name of a builtin algorithm.

    Parameters
    ----------
    algo_type: str
        can be one of 'tuners', 'assessors', 'advisors'
    builtin_name: str
        builtin name.

    Returns: tuple
    -------
    (module name, class name), or (None, None) when the algorithm is unknown.
    """
    assert algo_type in ALGO_TYPES
    assert builtin_name is not None
    meta = get_builtin_algo_meta(algo_type, builtin_name)
    if meta:
        return _parse_full_class_name(meta['class_name'])
    return None, None
def create_validator_instance(algo_type, builtin_name):
    """Create an instance of a builtin algorithm's classArgs validator.

    Parameters
    ----------
    algo_type: str
        can be one of 'tuners', 'assessors', 'advisors'
    builtin_name: str
        builtin name.

    Returns: object | None
    -------
    The validator instance, or None when the algorithm has no
    'class_args_validator' entry in its meta (or is unknown).
    """
    assert algo_type in ALGO_TYPES
    assert builtin_name is not None
    meta = get_builtin_algo_meta(algo_type, builtin_name)
    if not meta or 'class_args_validator' not in meta:
        return None
    module_name, class_name = _parse_full_class_name(meta['class_args_validator'])
    validator_cls = getattr(importlib.import_module(module_name), class_name)
    return validator_cls()
def create_builtin_class_instance(builtin_name, input_class_args, algo_type):
    """Create instance of a builtin algorithm.

    Parameters
    ----------
    builtin_name: str
        builtin name.
    input_class_args: dict
        kwargs for builtin class constructor
    algo_type: str
        can be one of 'tuners', 'assessors', 'advisors'

    Returns: object
    -------
    Returns builtin class instance.

    Raises
    ------
    RuntimeError
        If the builtin name is unknown or its module cannot be loaded.
    """
    assert algo_type in ALGO_TYPES
    if builtin_name not in get_all_builtin_names(algo_type):
        raise RuntimeError('Builtin name is not found: {}'.format(builtin_name))

    def parse_algo_meta(algo_meta, input_class_args):
        """
        1. parse class_name field in meta data into module name and class name,
           for example:
           parse class_name 'nni.hyperopt_tuner.hyperopt_tuner.HyperoptTuner' in meta data into:
           module name: nni.hyperopt_tuner.hyperopt_tuner
           class name: HyperoptTuner
        2. merge user specified class args together with builtin class args.
        """
        assert algo_meta
        module_name, class_name = _parse_full_class_name(algo_meta['class_name'])
        # Copy the meta's class_args: updating it in place would mutate the
        # shared meta dict (BuiltinAlgorithms / installed-package config), so
        # user args from one call would leak into every later call.
        class_args = dict(algo_meta.get('class_args', {}))
        if input_class_args is not None:
            class_args.update(input_class_args)
        return module_name, class_name, class_args

    algo_meta = get_builtin_algo_meta(algo_type, builtin_name)
    module_name, class_name, class_args = parse_algo_meta(algo_meta, input_class_args)

    # find_spec requires importlib.util (imported at module top); returns None
    # when the module is not importable, which we surface as a RuntimeError.
    if importlib.util.find_spec(module_name) is None:
        raise RuntimeError('Builtin module can not be loaded: {}'.format(module_name))

    class_module = importlib.import_module(module_name)
    class_constructor = getattr(class_module, class_name)
    instance = class_constructor(**class_args)
    return instance
def create_customized_class_instance(class_params):
    """Create an instance of a user-customized algorithm.

    Parameters
    ----------
    class_params: dict
        Expected keys:
        codeDir: code directory
        classFileName: python file name of the class
        className: class name
        classArgs (optional): kwargs passed to the class constructor

    Returns: object
    -------
    The customized class instance.

    Raises
    ------
    ValueError
        If the class file does not exist under codeDir.
    """
    code_dir = class_params.get('codeDir')
    class_filename = class_params.get('classFileName')
    class_name = class_params.get('className')
    class_args = class_params.get('classArgs')

    class_file = os.path.join(code_dir, class_filename)
    if not os.path.isfile(class_file):
        raise ValueError('Class file not found: {}'.format(class_file))

    # Make the user's code directory importable, then import by module name.
    sys.path.append(code_dir)
    module_name = os.path.splitext(class_filename)[0]
    class_constructor = getattr(importlib.import_module(module_name), class_name)

    if class_args is None:
        class_args = {}
    return class_constructor(**class_args)
def get_python_dir(sitepackages_path):
    """Derive the python installation root from a site-packages path.

    On Windows the site-packages directory itself is returned; elsewhere the
    third parent (e.g. /usr for /usr/lib/python3.8/site-packages).
    """
    path = Path(sitepackages_path)
    if sys.platform == "win32":
        return str(path)
    return str(path.parents[2])
def get_nni_installation_parent_dir():
    ''' Find nni installation parent directory.

    Returns the directory that contains the `nni` package (verified by the
    presence of nni/main.js), or the VIRTUAL_ENV root when running inside a
    virtualenv. May return None when nni is not found in any site-packages.
    '''
    def try_installation_path_sequentially(*sitepackages):
        '''Try different installation path sequentially util nni is found.
        Return None if nothing is found
        '''
        def _generate_installation_path(sitepackages_path):
            # A candidate is accepted only if <python_dir>/nni/main.js exists.
            python_dir = get_python_dir(sitepackages_path)
            entry_file = os.path.join(python_dir, 'nni', 'main.js')
            if os.path.isfile(entry_file):
                return python_dir
            return None

        for sitepackage in sitepackages:
            python_dir = _generate_installation_path(sitepackage)
            if python_dir:
                return python_dir
        return None

    if os.getenv('VIRTUAL_ENV'):
        # if 'virtualenv' package is used, `site` has not attr getsitepackages, so we will instead use VIRTUAL_ENV
        # Note that conda venv will not have VIRTUAL_ENV
        # NOTE(review): in this branch the main.js existence check above is
        # skipped — the VIRTUAL_ENV path is trusted as-is; callers must verify.
        python_dir = os.getenv('VIRTUAL_ENV')
    else:
        python_sitepackage = site.getsitepackages()[0]
        # If system-wide python is used, we will give priority to using `local sitepackage`--"usersitepackages()" given
        # that nni exists there
        if python_sitepackage.startswith('/usr') or python_sitepackage.startswith('/Library'):
            python_dir = try_installation_path_sequentially(site.getusersitepackages(), site.getsitepackages()[0])
        else:
            python_dir = try_installation_path_sequentially(site.getsitepackages()[0], site.getusersitepackages())

    return python_dir
def get_nni_installation_path():
    '''Find the nni installation directory.

    Returns <parent>/nni when <parent>/nni/main.js exists, otherwise None.
    '''
    parent = get_nni_installation_parent_dir()
    if not parent:
        return None
    nni_dir = os.path.join(parent, 'nni')
    if os.path.isfile(os.path.join(nni_dir, 'main.js')):
        return nni_dir
    return None
def get_nni_config_dir():
    """Return the path of NNI's config directory inside the installation."""
    install_path = get_nni_installation_path()
    return os.path.join(install_path, 'config')
def get_package_config_path():
    """Return the installed-packages YAML path, creating the config dir if needed."""
    config_dir = get_nni_config_dir()
    # exist_ok makes the creation idempotent, so no pre-check is needed.
    os.makedirs(config_dir, exist_ok=True)
    return os.path.join(config_dir, 'installed_packages.yml')
def read_installed_package_meta():
    """Load installed-package meta from the config file.

    Always returns a mapping with a (possibly empty) list for every entry in
    ALGO_TYPES, even when the config file is missing.
    """
    config_file = get_package_config_path()
    if os.path.exists(config_file):
        with open(config_file, 'r') as f:
            config = yaml.load(f, Loader=yaml.Loader)
    else:
        config = defaultdict(list)
    for algo_type in ALGO_TYPES:
        config.setdefault(algo_type, [])
    return config
def write_package_meta(config):
    """Persist installed-package meta to the config file as YAML."""
    with open(get_package_config_path(), 'w') as f:
        f.write(yaml.dump(dict(config), default_flow_style=False))
def _get_merged_builtin_dict():
    """Merge pre-installed BuiltinAlgorithms with user-installed package meta.

    Pre-installed entries come first in each category's list.
    """
    installed = read_installed_package_meta()
    merged = defaultdict(list)
    for algo_type in ALGO_TYPES:
        merged[algo_type] = BuiltinAlgorithms[algo_type] + installed[algo_type]
    return merged
......@@ -6,8 +6,10 @@ import logging
import os
import random
import numpy as np
from schema import Schema, Optional
import nni
from nni import ClassArgsValidator
import nni.parameter_expressions
from nni.tuner import Tuner
from nni.utils import OptimizeMode, extract_scalar_reward, split_index, json2parameter, json2space
......@@ -157,6 +159,15 @@ class TrialInfo:
def clean_id(self):
self.parameter_id = None
class PBTClassArgsValidator(ClassArgsValidator):
    """Validate the `classArgs` section for the PBT tuner."""

    def validate_class_args(self, **kwargs):
        """Raise a schema error when the given classArgs are not valid for PBT."""
        rules = {
            'optimize_mode': self.choices('optimize_mode', 'maximize', 'minimize'),
            Optional('all_checkpoint_dir'): str,
            Optional('population_size'): self.range('population_size', int, 0, 99999),
            Optional('factors'): float,
            Optional('fraction'): float,
        }
        Schema(rules).validate(kwargs)
class PBTTuner(Tuner):
def __init__(self, optimize_mode="maximize", all_checkpoint_dir=None, population_size=10, factor=0.2,
......
......@@ -10,8 +10,10 @@ import copy
import logging
import numpy as np
from gym import spaces
from schema import Schema, Optional
import nni
from nni import ClassArgsValidator
from nni.tuner import Tuner
from nni.utils import OptimizeMode, extract_scalar_reward
......@@ -285,6 +287,21 @@ class PPOModel:
mbstates = states[mbenvinds]
self.model.train(lrnow, cliprangenow, *slices, mbstates)
class PPOClassArgsValidator(ClassArgsValidator):
    """Validate the `classArgs` section for the PPO tuner."""

    def validate_class_args(self, **kwargs):
        """Raise a schema error when the given classArgs are not valid for PPO."""
        rules = {
            'optimize_mode': self.choices('optimize_mode', 'maximize', 'minimize'),
            Optional('trials_per_update'): self.range('trials_per_update', int, 0, 99999),
            Optional('epochs_per_update'): self.range('epochs_per_update', int, 0, 99999),
            Optional('minibatch_size'): self.range('minibatch_size', int, 0, 99999),
            Optional('ent_coef'): float,
            Optional('lr'): float,
            Optional('vf_coef'): float,
            Optional('max_grad_norm'): float,
            Optional('gamma'): float,
            Optional('lam'): float,
            Optional('cliprange'): float,
        }
        Schema(rules).validate(kwargs)
class PPOTuner(Tuner):
"""
......
......@@ -9,6 +9,7 @@ import logging
import sys
import numpy as np
from schema import Schema, Optional
from smac.facade.epils_facade import EPILS
from smac.facade.roar_facade import ROAR
......@@ -19,6 +20,7 @@ from smac.utils.io.cmd_reader import CMDReader
from ConfigSpaceNNI import Configuration
import nni
from nni import ClassArgsValidator
from nni.tuner import Tuner
from nni.utils import OptimizeMode, extract_scalar_reward
......@@ -26,6 +28,13 @@ from .convert_ss_to_scenario import generate_scenario
logger = logging.getLogger('smac_AutoML')
class SMACClassArgsValidator(ClassArgsValidator):
    """Validate the `classArgs` section for the SMAC tuner."""

    def validate_class_args(self, **kwargs):
        """Raise a schema error when the given classArgs are not valid for SMAC."""
        rules = {
            'optimize_mode': self.choices('optimize_mode', 'maximize', 'minimize'),
            Optional('config_dedup'): bool,
        }
        Schema(rules).validate(kwargs)
class SMACTuner(Tuner):
"""
This is a wrapper of [SMAC](https://github.com/automl/SMAC3) following NNI tuner interface.
......
......@@ -6,6 +6,7 @@ import copy
import functools
from enum import Enum, unique
import json_tricks
from schema import And
from . import parameter_expressions
from .common import init_logger
......@@ -217,7 +218,6 @@ def json2parameter(x, is_rand, random_state, oldy=None, Rand=False, name=NodeTyp
y = copy.deepcopy(x)
return y
def merge_parameter(base_params, override_params):
"""
Update the parameters in ``base_params`` with ``override_params``.
......@@ -256,3 +256,64 @@ def merge_parameter(base_params, override_params):
(k, type(getattr(base_params, k)), type(v)))
setattr(base_params, k, v)
return base_params
class ClassArgsValidator(object):
    """
    NNI tuners/assessors/advisors accept a `classArgs` parameter in experiment configuration file.
    This ClassArgsValidator interface is used to validate the classArgs section in experiment
    configuration file.
    """
    def validate_class_args(self, **kwargs):
        """
        Validate the classArgs configuration in experiment configuration file.

        Parameters
        ----------
        kwargs: dict
            kwargs passed to tuner/assessor/advisor constructor

        Raises:
            Raise an exception if the kwargs is invalid.
        """
        # Default implementation accepts anything; subclasses override with
        # a Schema describing their constructor's keyword arguments.
        pass

    def choices(self, key, *args):
        """
        Utility method to create a schema to check whether the `key` is one of the `args`.

        Parameters:
        ----------
        key: str
            key name of the data to be validated
        args: list of str
            list of the valid choices

        Returns: Schema
        --------
        A schema to check whether the `key` is one of the `args`.
        """
        return And(lambda n: n in args, error='%s should be in [%s]!' % (key, str(args)))

    def range(self, key, keyType, start, end):
        """
        Utility method to create a schema to check whether the `key` is in the range of [start, end].

        Parameters:
        ----------
        key: str
            key name of the data to be validated
        keyType: type
            python data type, such as int, float
        start: type is specified by keyType
            start of the range (inclusive)
        end: type is specified by keyType
            end of the range (inclusive)

        Returns: Schema
        --------
        A schema to check whether the `key` is in the range of [start, end].
        """
        # NOTE: the check is inclusive on both ends; the "(start, end)" in the
        # error string is only display notation, not an open interval.
        return And(
            And(keyType, error='%s should be %s type!' % (key, keyType.__name__)),
            And(lambda n: start <= n <= end, error='%s should be in range of (%s, %s)!' % (key, start, end))
        )
......@@ -25,3 +25,9 @@ cd ${CWD}/../src/nni_manager
echo ""
echo "===========================Testing: nni_manager==========================="
npm run test
## ------Run nnictl unit test------
echo ""
echo "===========================Testing: nnictl==========================="
cd ${CWD}/../tools/nni_cmd/
python3 -m unittest discover -v tests
......@@ -62,7 +62,7 @@ def install_requirements_command(requirements_path):
requirements_path: str
Path to the directory that contains `requirements.txt`.
"""
call(_get_pip_install() + ["-r", os.path.join(requirements_path, "requirements.txt")], shell=False)
return call(_get_pip_install() + ["-r", requirements_path], shell=False)
def _get_pip_install():
......@@ -72,3 +72,11 @@ def _get_pip_install():
(sys.platform != "win32" and os.getuid() != 0): # on unix and not running in root
ret.append("--user") # not in virtualenv or conda
return ret
def call_pip_install(source):
    # Install `source` (a package path or requirement spec) via pip using the
    # shared pip-install command line; returns the subprocess exit code.
    return call(_get_pip_install() + [source])
def call_pip_uninstall(module_name):
    """Uninstall `module_name` via the platform's pip; return the exit code."""
    interpreter = "python" if sys.platform == "win32" else "python3"
    return call([interpreter, "-m", "pip", "uninstall", module_name])
......@@ -2,14 +2,14 @@
# Licensed under the MIT license.
import os
import site
import sys
import json
import socket
from pathlib import Path
import ruamel.yaml as yaml
import psutil
from .constants import ERROR_INFO, NORMAL_INFO, WARNING_INFO, COLOR_RED_FORMAT, COLOR_YELLOW_FORMAT
from colorama import Fore
from .constants import ERROR_INFO, NORMAL_INFO, WARNING_INFO
def get_yml_content(file_path):
'''Load yaml file content'''
......@@ -34,17 +34,22 @@ def get_json_content(file_path):
print_error(err)
return None
def print_error(content):
def print_error(*content):
'''Print error information to screen'''
print(COLOR_RED_FORMAT % (ERROR_INFO % content))
print(Fore.RED + ERROR_INFO + ' '.join([str(c) for c in content]) + Fore.RESET)
def print_green(*content):
    '''Print information to screen in green'''
    # Join all arguments with spaces and wrap them in colorama's green escape codes.
    print(Fore.GREEN + ' '.join([str(c) for c in content]) + Fore.RESET)
def print_normal(content):
def print_normal(*content):
'''Print error information to screen'''
print(NORMAL_INFO % content)
print(NORMAL_INFO, *content)
def print_warning(content):
def print_warning(*content):
'''Print warning information to screen'''
print(COLOR_YELLOW_FORMAT % (WARNING_INFO % content))
print(Fore.YELLOW + WARNING_INFO + ' '.join([str(c) for c in content]) + Fore.RESET)
def detect_process(pid):
'''Detect if a process is alive'''
......@@ -70,12 +75,6 @@ def get_user():
else:
return os.environ['USER']
def get_python_dir(sitepackages_path):
if sys.platform == "win32":
return str(Path(sitepackages_path))
else:
return str(Path(sitepackages_path).parents[2])
def check_tensorboard_version():
try:
import tensorboard
......@@ -84,43 +83,3 @@ def check_tensorboard_version():
print_error('import tensorboard error!')
exit(1)
def get_nni_installation_path():
''' Find nni lib from the following locations in order
Return nni root directory if it exists
'''
def try_installation_path_sequentially(*sitepackages):
'''Try different installation path sequentially util nni is found.
Return None if nothing is found
'''
def _generate_installation_path(sitepackages_path):
python_dir = get_python_dir(sitepackages_path)
entry_file = os.path.join(python_dir, 'nni', 'main.js')
if os.path.isfile(entry_file):
return python_dir
return None
for sitepackage in sitepackages:
python_dir = _generate_installation_path(sitepackage)
if python_dir:
return python_dir
return None
if os.getenv('VIRTUAL_ENV'):
# if 'virtualenv' package is used, `site` has not attr getsitepackages, so we will instead use VIRTUAL_ENV
# Note that conda venv will not have VIRTUAL_ENV
python_dir = os.getenv('VIRTUAL_ENV')
else:
python_sitepackage = site.getsitepackages()[0]
# If system-wide python is used, we will give priority to using `local sitepackage`--"usersitepackages()" given
# that nni exists there
if python_sitepackage.startswith('/usr') or python_sitepackage.startswith('/Library'):
python_dir = try_installation_path_sequentially(site.getusersitepackages(), site.getsitepackages()[0])
else:
python_dir = try_installation_path_sequentially(site.getsitepackages()[0], site.getusersitepackages())
if python_dir:
entry_file = os.path.join(python_dir, 'nni', 'main.js')
if os.path.isfile(entry_file):
return os.path.join(python_dir, 'nni')
print_error('Fail to find nni under python library')
exit(1)
\ No newline at end of file
This diff is collapsed.
......@@ -6,14 +6,11 @@ from colorama import Fore
NNICTL_HOME_DIR = os.path.join(os.path.expanduser('~'), '.local', 'nnictl')
ERROR_INFO = 'ERROR: %s'
NORMAL_INFO = 'INFO: %s'
WARNING_INFO = 'WARNING: %s'
ERROR_INFO = 'ERROR: '
NORMAL_INFO = 'INFO: '
WARNING_INFO = 'WARNING: '
DEFAULT_REST_PORT = 8080
REST_TIME_OUT = 20
EXPERIMENT_SUCCESS_INFO = Fore.GREEN + 'Successfully started experiment!\n' + Fore.RESET + \
......@@ -62,10 +59,25 @@ TRIAL_MONITOR_CONTENT = '%-15s %-25s %-25s %-15s'
TRIAL_MONITOR_TAIL = '-------------------------------------------------------------------------------------\n\n\n'
PACKAGE_REQUIREMENTS = {
'SMAC': 'smac_tuner',
'BOHB': 'bohb_advisor',
'PPOTuner': 'ppo_tuner'
INSTALLABLE_PACKAGE_META = {
'SMAC': {
'type': 'tuner',
'class_name': 'nni.smac_tuner.smac_tuner.SMACTuner',
'code_sub_dir': 'smac_tuner',
'class_args_validator': 'nni.smac_tuner.smac_tuner.SMACClassArgsValidator'
},
'BOHB': {
'type': 'advisor',
'class_name': 'nni.bohb_advisor.bohb_advisor.BOHB',
'code_sub_dir': 'bohb_advisor',
'class_args_validator': 'nni.bohb_advisor.bohb_advisor.BOHBClassArgsValidator'
},
'PPOTuner': {
'type': 'tuner',
'class_name': 'nni.ppo_tuner.ppo_tuner.PPOTuner',
'code_sub_dir': 'ppo_tuner',
'class_args_validator': 'nni.ppo_tuner.ppo_tuner.PPOClassArgsValidator'
}
}
TUNERS_SUPPORTING_IMPORT_DATA = {
......@@ -83,14 +95,6 @@ TUNERS_NO_NEED_TO_IMPORT_DATA = {
'Hyperband'
}
COLOR_RED_FORMAT = Fore.RED + '%s'
COLOR_GREEN_FORMAT = Fore.GREEN + '%s'
COLOR_YELLOW_FORMAT = Fore.YELLOW + '%s'
SCHEMA_TYPE_ERROR = '%s should be %s type!'
SCHEMA_RANGE_ERROR = '%s should be in range of %s!'
SCHEMA_PATH_ERROR = '%s path not exist!'
......@@ -10,14 +10,15 @@ import time
import tempfile
from subprocess import Popen, check_call, CalledProcessError, PIPE, STDOUT
from nni_annotation import expand_annotations, generate_search_space
from nni.constants import ModuleName, AdvisorModuleName
from nni.package_utils import get_builtin_module_class_name, get_nni_installation_path
from .launcher_utils import validate_all_content
from .rest_utils import rest_put, rest_post, check_rest_server, check_response
from .url_utils import cluster_metadata_url, experiment_url, get_local_urls
from .config_utils import Config, Experiments
from .common_utils import get_yml_content, get_json_content, print_error, print_normal, \
detect_port, get_user, get_nni_installation_path
from .constants import NNICTL_HOME_DIR, ERROR_INFO, REST_TIME_OUT, EXPERIMENT_SUCCESS_INFO, LOG_HEADER, PACKAGE_REQUIREMENTS
detect_port, get_user
from .constants import NNICTL_HOME_DIR, ERROR_INFO, REST_TIME_OUT, EXPERIMENT_SUCCESS_INFO, LOG_HEADER, INSTALLABLE_PACKAGE_META
from .command_utils import check_output_command, kill_command
from .nnictl_utils import update_experiment
......@@ -52,6 +53,9 @@ def start_rest_server(port, platform, mode, config_file_name, foreground=False,
print_normal('Starting restful server...')
entry_dir = get_nni_installation_path()
if (not entry_dir) or (not os.path.exists(entry_dir)):
print_error('Fail to find nni under python library')
exit(1)
entry_file = os.path.join(entry_dir, 'main.js')
node_command = 'node'
......@@ -390,10 +394,10 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
package_name, module_name = None, None
if experiment_config.get('tuner') and experiment_config['tuner'].get('builtinTunerName'):
package_name = experiment_config['tuner']['builtinTunerName']
module_name = ModuleName.get(package_name)
module_name, _ = get_builtin_module_class_name('tuners', package_name)
elif experiment_config.get('advisor') and experiment_config['advisor'].get('builtinAdvisorName'):
package_name = experiment_config['advisor']['builtinAdvisorName']
module_name = AdvisorModuleName.get(package_name)
module_name, _ = get_builtin_module_class_name('advisors', package_name)
if package_name and module_name:
try:
stdout_full_path, stderr_full_path = get_log_path(config_file_name)
......@@ -402,7 +406,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
except CalledProcessError:
print_error('some errors happen when import package %s.' %(package_name))
print_log_content(config_file_name)
if package_name in PACKAGE_REQUIREMENTS:
if package_name in INSTALLABLE_PACKAGE_META:
print_error('If %s is not installed, it should be installed through '\
'\'nnictl package install --name %s\''%(package_name, package_name))
exit(1)
......@@ -502,7 +506,11 @@ def create_experiment(args):
print_error('Please set correct config path!')
exit(1)
experiment_config = get_yml_content(config_path)
validate_all_content(experiment_config, config_path)
try:
validate_all_content(experiment_config, config_path)
except Exception as e:
print_error(e)
exit(1)
nni_config.set_config('experimentConfig', experiment_config)
nni_config.set_config('restServerPort', args.port)
......
......@@ -2,14 +2,9 @@
# Licensed under the MIT license.
import os
import json
import netifaces
from schema import SchemaError
from schema import Schema
from .config_schema import LOCAL_CONFIG_SCHEMA, REMOTE_CONFIG_SCHEMA, PAI_CONFIG_SCHEMA, PAI_YARN_CONFIG_SCHEMA, \
DLTS_CONFIG_SCHEMA, KUBEFLOW_CONFIG_SCHEMA, FRAMEWORKCONTROLLER_CONFIG_SCHEMA, \
tuner_schema_dict, advisor_schema_dict, assessor_schema_dict
from .common_utils import print_error, print_warning, print_normal, get_yml_content
from .config_schema import NNIConfigSchema
from .common_utils import print_normal
def expand_path(experiment_config, key):
'''Change '~' to user home directory'''
......@@ -27,12 +22,10 @@ def parse_time(time):
'''Change the time to seconds'''
unit = time[-1]
if unit not in ['s', 'm', 'h', 'd']:
print_error('the unit of time could only from {s, m, h, d}')
exit(1)
raise SchemaError('the unit of time could only from {s, m, h, d}')
time = time[:-1]
if not time.isdigit():
print_error('time format error!')
exit(1)
raise SchemaError('time format error!')
parse_dict = {'s':1, 'm':60, 'h':3600, 'd':86400}
return int(time) * parse_dict[unit]
......@@ -101,100 +94,7 @@ def parse_path(experiment_config, config_path):
if experiment_config['trial'].get('paiConfigPath'):
parse_relative_path(root_path, experiment_config['trial'], 'paiConfigPath')
def validate_search_space_content(experiment_config):
    '''Validate searchspace content,
    if the searchspace file is not json format or its values does not contain _type and _value which must be specified,
    it will not be a valid searchspace file'''
    try:
        # NOTE(review): the file handle from open() is never closed here.
        search_space_content = json.load(open(experiment_config.get('searchSpacePath'), 'r'))
        for value in search_space_content.values():
            if not value.get('_type') or not value.get('_value'):
                print_error('please use _type and _value to specify searchspace!')
                exit(1)
    # NOTE(review): bare except also catches the SystemExit raised by exit(1)
    # above, so a _type/_value error additionally prints the json-format error.
    except:
        print_error('searchspace file is not a valid json format!')
        exit(1)
def validate_kubeflow_operators(experiment_config):
    '''Validate whether the kubeflow operators are valid'''
    if experiment_config.get('kubeflowConfig'):
        # tf-operator: requires 'worker', forbids 'master' in the trial config.
        if experiment_config.get('kubeflowConfig').get('operator') == 'tf-operator':
            if experiment_config.get('trial').get('master') is not None:
                print_error('kubeflow with tf-operator can not set master')
                exit(1)
            if experiment_config.get('trial').get('worker') is None:
                print_error('kubeflow with tf-operator must set worker')
                exit(1)
        # pytorch-operator: requires 'master', forbids 'ps' in the trial config.
        elif experiment_config.get('kubeflowConfig').get('operator') == 'pytorch-operator':
            if experiment_config.get('trial').get('ps') is not None:
                print_error('kubeflow with pytorch-operator can not set ps')
                exit(1)
            if experiment_config.get('trial').get('master') is None:
                print_error('kubeflow with pytorch-operator must set master')
                exit(1)

        # Storage backend must be configured consistently with its type.
        if experiment_config.get('kubeflowConfig').get('storage') == 'nfs':
            if experiment_config.get('kubeflowConfig').get('nfs') is None:
                print_error('please set nfs configuration!')
                exit(1)
        elif experiment_config.get('kubeflowConfig').get('storage') == 'azureStorage':
            if experiment_config.get('kubeflowConfig').get('azureStorage') is None:
                print_error('please set azureStorage configuration!')
                exit(1)
        elif experiment_config.get('kubeflowConfig').get('storage') is None:
            if experiment_config.get('kubeflowConfig').get('azureStorage'):
                print_error('please set storage type!')
                exit(1)
def validate_common_content(experiment_config):
    '''Validate whether the common values in experiment_config is valid'''
    if not experiment_config.get('trainingServicePlatform') or \
        experiment_config.get('trainingServicePlatform') not in [
            'local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts'
        ]:
        print_error('Please set correct trainingServicePlatform!')
        exit(1)
    # Whole-config schema, chosen by the training service platform.
    schema_dict = {
        'local': LOCAL_CONFIG_SCHEMA,
        'remote': REMOTE_CONFIG_SCHEMA,
        'pai': PAI_CONFIG_SCHEMA,
        'paiYarn': PAI_YARN_CONFIG_SCHEMA,
        'kubeflow': KUBEFLOW_CONFIG_SCHEMA,
        'frameworkcontroller': FRAMEWORKCONTROLLER_CONFIG_SCHEMA,
        'dlts': DLTS_CONFIG_SCHEMA,
    }
    # Per-section schema dicts for the algorithm sub-sections of the config.
    separate_schema_dict = {
        'tuner': tuner_schema_dict,
        'advisor': advisor_schema_dict,
        'assessor': assessor_schema_dict
    }
    # Config key that names a builtin algorithm within each section.
    separate_builtInName_dict = {
        'tuner': 'builtinTunerName',
        'advisor': 'builtinAdvisorName',
        'assessor': 'builtinAssessorName'
    }
    try:
        schema_dict.get(experiment_config['trainingServicePlatform']).validate(experiment_config)
        for separate_key in separate_schema_dict.keys():
            if experiment_config.get(separate_key):
                if experiment_config[separate_key].get(separate_builtInName_dict[separate_key]):
                    validate = False
                    # NOTE(review): keys of the per-algorithm schema dicts appear
                    # to be containers of builtin names; the first key containing
                    # this builtin name supplies the schema — confirm in the
                    # schema-dict definitions (not visible here).
                    for key in separate_schema_dict[separate_key].keys():
                        if key.__contains__(experiment_config[separate_key][separate_builtInName_dict[separate_key]]):
                            Schema({**separate_schema_dict[separate_key][key]}).validate(experiment_config[separate_key])
                            validate = True
                            break
                    if not validate:
                        print_error('%s %s error!' % (separate_key, separate_builtInName_dict[separate_key]))
                        exit(1)
                else:
                    # No builtin name: fall back to the 'customized' schema.
                    Schema({**separate_schema_dict[separate_key]['customized']}).validate(experiment_config[separate_key])
    except SchemaError as error:
        print_error('Your config file is not correct, please check your config file content!')
        print_error(error.code)
        exit(1)
#set default value
def set_default_values(experiment_config):
if experiment_config.get('maxExecDuration') is None:
experiment_config['maxExecDuration'] = '999d'
if experiment_config.get('maxTrialNum') is None:
......@@ -204,124 +104,11 @@ def validate_common_content(experiment_config):
if experiment_config['machineList'][index].get('port') is None:
experiment_config['machineList'][index]['port'] = 22
def validate_customized_file(experiment_config, spec_key):
    '''
    check whether the file of customized tuner/assessor/advisor exists
    spec_key: 'tuner', 'assessor', 'advisor'
    '''
    spec = experiment_config[spec_key]
    required_fields = ('codeDir', 'classFileName', 'className')
    fields_present = all(spec.get(field) for field in required_fields)
    if fields_present:
        class_file = os.path.join(spec['codeDir'], spec['classFileName'])
        if os.path.exists(class_file):
            return
    # Either a required field is missing or the class file does not exist.
    print_error('%s file directory is not valid!'%(spec_key))
    exit(1)
def parse_tuner_content(experiment_config):
    '''Validate whether tuner in experiment_config is valid'''
    # A builtin tuner needs no file check; a customized one must be validated.
    if experiment_config['tuner'].get('builtinTunerName'):
        return
    validate_customized_file(experiment_config, 'tuner')
def parse_assessor_content(experiment_config):
    '''Validate whether assessor in experiment_config is valid'''
    assessor_config = experiment_config.get('assessor')
    # Assessor is optional; only a customized assessor needs a file check.
    if assessor_config and not assessor_config.get('builtinAssessorName'):
        validate_customized_file(experiment_config, 'assessor')
def parse_advisor_content(experiment_config):
    '''Validate whether advisor in experiment_config is valid'''
    # A builtin advisor needs no file check; a customized one must be validated.
    if experiment_config['advisor'].get('builtinAdvisorName'):
        return
    validate_customized_file(experiment_config, 'advisor')
def validate_annotation_content(experiment_config, spec_key, builtin_name):
    '''
    Valid whether useAnnotation and searchSpacePath is coexist
    spec_key: 'advisor' or 'tuner'
    builtin_name: 'builtinAdvisorName' or 'builtinTunerName'
    '''
    if experiment_config.get('useAnnotation'):
        # Annotation generates the search space, so an explicit path conflicts.
        if experiment_config.get('searchSpacePath'):
            print_error('If you set useAnnotation=true, please leave searchSpacePath empty')
            exit(1)
        return
    # validate searchSpaceFile
    algo_name = experiment_config[spec_key].get(builtin_name)
    if algo_name == 'NetworkMorphism':
        return
    if algo_name:
        if experiment_config.get('searchSpacePath') is None:
            print_error('Please set searchSpacePath!')
            exit(1)
        validate_search_space_content(experiment_config)
def validate_machine_list(experiment_config):
    '''Validate machine list'''
    uses_remote = experiment_config.get('trainingServicePlatform') == 'remote'
    # Remote training requires an explicit machine list.
    if uses_remote and experiment_config.get('machineList') is None:
        print_error('Please set machineList!')
        exit(1)
def validate_pai_config_path(experiment_config):
    '''validate paiConfigPath field'''
    if experiment_config.get('trainingServicePlatform') != 'pai':
        return
    if experiment_config.get('trial', {}).get('paiConfigPath'):
        # validate commands
        pai_config = get_yml_content(experiment_config['trial']['paiConfigPath'])
        if not pai_config.get('taskRoles'):
            print_error('Please set taskRoles in paiConfigPath config file!')
            exit(1)
        return
    # Without an external pai config file, every pai trial field is mandatory.
    required_fields = ('image', 'gpuNum', 'cpuNum', 'memoryMB', 'paiStorageConfigName', 'command')
    for trial_field in required_fields:
        if experiment_config['trial'].get(trial_field) is None:
            print_error('Please set {0} in trial configuration,\
or set additional pai configuration file path in paiConfigPath!'.format(trial_field))
            exit(1)
def validate_pai_trial_conifg(experiment_config):
    '''validate the trial config in pai platform'''
    if experiment_config.get('trainingServicePlatform') in ['pai', 'paiYarn']:
        trial_config = experiment_config.get('trial')
        if trial_config.get('shmMB') and trial_config['shmMB'] > trial_config['memoryMB']:
            print_error('shmMB should be no more than memoryMB!')
            exit(1)
        #backward compatibility
        warning_information = '{0} is not supported in NNI anymore, please remove the field in config file!\
please refer https://github.com/microsoft/nni/blob/master/docs/en_US/TrainingService/PaiMode.md#run-an-experiment\
for the practices of how to get data and output model in trial code'
        for deprecated_field in ('dataDir', 'outputDir'):
            if trial_config.get(deprecated_field):
                print_warning(warning_information.format(deprecated_field))
    validate_pai_config_path(experiment_config)
def validate_eth0_device(experiment_config):
    '''validate whether the machine has eth0 device'''
    # Only non-local platforms without an explicit nniManagerIp need eth0.
    needs_eth0 = (experiment_config.get('trainingServicePlatform') not in ['local']
                  and not experiment_config.get('nniManagerIp'))
    if needs_eth0 and 'eth0' not in netifaces.interfaces():
        print_error('This machine does not contain eth0 network device, please set nniManagerIp in config file!')
        exit(1)
def validate_all_content(experiment_config, config_path):
    '''Validate whether experiment_config is valid'''
    # Resolve path fields relative to the config file location first.
    parse_path(experiment_config, config_path)
    validate_common_content(experiment_config)
    validate_eth0_device(experiment_config)
    validate_pai_trial_conifg(experiment_config)
    set_default_values(experiment_config)
    NNIConfigSchema().validate(experiment_config)
    # Convert the human-readable duration string (e.g. '999d') via parse_time
    # — presumably to seconds; confirm against the utils implementation.
    experiment_config['maxExecDuration'] = parse_time(experiment_config['maxExecDuration'])
    if experiment_config.get('advisor'):
        # An advisor plays both tuner and assessor roles, so it is exclusive.
        if experiment_config.get('assessor') or experiment_config.get('tuner'):
            print_error('advisor could not be set with assessor or tuner simultaneously!')
            exit(1)
        parse_advisor_content(experiment_config)
        validate_annotation_content(experiment_config, 'advisor', 'builtinAdvisorName')
    else:
        # Without an advisor, a tuner is mandatory; an assessor is optional.
        if not experiment_config.get('tuner'):
            raise Exception('Please provide tuner spec!')
        parse_tuner_content(experiment_config)
        parse_assessor_content(experiment_config)
        validate_annotation_content(experiment_config, 'tuner', 'builtinTunerName')
......@@ -12,7 +12,7 @@ from .nnictl_utils import stop_experiment, trial_ls, trial_kill, list_experiment
log_trial, experiment_clean, platform_clean, experiment_list, \
monitor_experiment, export_trials_data, trial_codegen, webui_url, \
get_config, log_stdout, log_stderr, search_space_auto_gen, webui_nas
from .package_management import package_install, package_show
from .package_management import package_install, package_uninstall, package_show, package_list
from .constants import DEFAULT_REST_PORT
from .tensorboard_utils import start_tensorboard, stop_tensorboard
init(autoreset=True)
......@@ -196,11 +196,22 @@ def parse_args():
# add subparsers for parser_package
parser_package_subparsers = parser_package.add_subparsers()
parser_package_install = parser_package_subparsers.add_parser('install', help='install packages')
parser_package_install.add_argument('--name', '-n', dest='name', help='package name to be installed')
parser_package_install.add_argument('source', nargs='?', help='installation source, can be a directory or whl file')
parser_package_install.add_argument('--name', '-n', dest='name', help='package name to be installed', required=False)
parser_package_install.set_defaults(func=package_install)
parser_package_uninstall = parser_package_subparsers.add_parser('uninstall', help='uninstall packages')
parser_package_uninstall.add_argument('name', nargs=1, help='package name to be uninstalled')
parser_package_uninstall.set_defaults(func=package_uninstall)
parser_package_show = parser_package_subparsers.add_parser('show', help='show the information of packages')
parser_package_show.add_argument('name', nargs=1, help='builtin name of the package')
parser_package_show.set_defaults(func=package_show)
parser_package_list = parser_package_subparsers.add_parser('list', help='list installed packages')
parser_package_list.add_argument('--all', action='store_true', help='list all builtin packages')
parser_package_list.set_defaults(func=package_list)
#parse tensorboard command
parser_tensorboard = subparsers.add_parser('tensorboard', help='manage tensorboard')
parser_tensorboard_subparsers = parser_tensorboard.add_subparsers()
......
......@@ -13,13 +13,14 @@ from datetime import datetime, timezone
from pathlib import Path
from subprocess import Popen
from pyhdfs import HdfsClient
from nni.package_utils import get_nni_installation_path
from nni_annotation import expand_annotations
from .rest_utils import rest_get, rest_delete, check_rest_server_quick, check_response
from .url_utils import trial_jobs_url, experiment_url, trial_job_id_url, export_data_url
from .config_utils import Config, Experiments
from .constants import NNICTL_HOME_DIR, EXPERIMENT_INFORMATION_FORMAT, EXPERIMENT_DETAIL_FORMAT, \
EXPERIMENT_MONITOR_INFO, TRIAL_MONITOR_HEAD, TRIAL_MONITOR_CONTENT, TRIAL_MONITOR_TAIL, REST_TIME_OUT
from .common_utils import print_normal, print_error, print_warning, detect_process, get_yml_content, get_nni_installation_path
from .common_utils import print_normal, print_error, print_warning, detect_process, get_yml_content
from .command_utils import check_output_command, kill_command
from .ssh_utils import create_ssh_sftp_client, remove_remote_directory
......
......@@ -2,23 +2,183 @@
# Licensed under the MIT license.
import os
from collections import defaultdict
import json
import pkginfo
import nni
from .constants import PACKAGE_REQUIREMENTS
from .common_utils import print_error
from .command_utils import install_requirements_command
from nni.package_utils import read_installed_package_meta, get_installed_package_meta, \
write_package_meta, get_builtin_algo_meta, get_not_installable_builtin_names, ALGO_TYPES
def process_install(package_name):
    '''Install the requirements of a supported builtin package by name.'''
    if PACKAGE_REQUIREMENTS.get(package_name) is None:
        # Bug fix: the original applied '%' to a '{0}'-style template
        # ('{0} is not supported!' % package_name), which raises TypeError
        # instead of printing the message; use str.format instead.
        print_error('{0} is not supported!'.format(package_name))
    else:
        requirements_path = os.path.join(nni.__path__[0], PACKAGE_REQUIREMENTS[package_name])
        install_requirements_command(requirements_path)
from .constants import INSTALLABLE_PACKAGE_META
from .common_utils import print_error, print_green
from .command_utils import install_requirements_command, call_pip_install, call_pip_uninstall
# Valid values for the 'type' field of a package meta record
# (enforced in save_package_meta_data below).
PACKAGE_TYPES = ['tuner', 'assessor', 'advisor']
def install_by_name(package_name):
    '''Install an installable builtin package's pip requirements by name.

    Returns the exit status of the requirements installation command.
    Raises RuntimeError if the name is unknown, FileNotFoundError if the
    package ships no requirements.txt.
    '''
    if package_name not in INSTALLABLE_PACKAGE_META:
        raise RuntimeError('{} is not found in installable packages!'.format(package_name))
    requirements_path = os.path.join(nni.__path__[0], INSTALLABLE_PACKAGE_META[package_name]['code_sub_dir'], 'requirements.txt')
    # Raise explicitly instead of 'assert', which is stripped under python -O.
    if not os.path.exists(requirements_path):
        raise FileNotFoundError(requirements_path)
    return install_requirements_command(requirements_path)
def package_install(args):
    '''install packages

    Installs either a builtin package selected by --name, or a customized
    package from a directory / whl source, then records its meta data.
    '''
    # Bug fix: dropped the stale leading call to process_install(args.name),
    # which duplicated the installation work done by install_by_name below
    # (and failed outright when --name was not given).
    installed = False
    try:
        if args.name:
            # Builtin package: install requirements, then register its meta.
            if install_by_name(args.name) == 0:
                builtin_meta = INSTALLABLE_PACKAGE_META[args.name]
                package_meta = {
                    'type': builtin_meta['type'],
                    'name': args.name,
                    'class_name': builtin_meta['class_name'],
                    'class_args_validator': builtin_meta['class_args_validator'],
                }
                save_package_meta_data(package_meta)
                print_green('{} installed!'.format(args.name))
                installed = True
        else:
            # Customized package: extract NNI meta from its classifiers first.
            package_meta = get_nni_meta(args.source)
            if package_meta:
                if call_pip_install(args.source) == 0:
                    save_package_meta_data(package_meta)
                    print_green('{} installed!'.format(package_meta['name']))
                    installed = True
    except Exception as e:
        print_error(e)
    if not installed:
        print_error('installation failed!')
def package_uninstall(args):
    '''uninstall packages

    Refuses to uninstall non-installable builtins; otherwise pip-uninstalls
    the backing package (if any) and removes its meta data record.
    '''
    name = args.name[0]
    if name in get_not_installable_builtin_names():
        print_error('{} can not be uninstalled!'.format(name))
        exit(1)
    meta = get_installed_package_meta(None, name)
    if meta is None:
        print_error('package {} not found!'.format(name))
        return
    # Only packages installed via pip record an 'installed_package' name.
    if 'installed_package' in meta:
        call_pip_uninstall(meta['installed_package'])
    if remove_package_meta_data(name):
        # Bug fix: corrected user-facing typo 'sucessfully' -> 'successfully'.
        print_green('{} uninstalled successfully!'.format(name))
    else:
        print_error('Failed to uninstall {}!'.format(name))
def package_show(args):
    '''show specified packages'''
    # Bug fix: removed the dead leftover body of the old implementation
    # (a stray docstring plus print(' '.join(PACKAGE_REQUIREMENTS.keys())))
    # which dumped all requirement keys before showing the requested package.
    builtin_name = args.name[0]
    meta = get_builtin_algo_meta(builtin_name=builtin_name)
    if meta:
        print(json.dumps(meta, indent=4))
    else:
        print_error('package {} not found'.format(builtin_name))
def print_package_list(meta):
    '''Print an ASCII table of package meta data grouped by algorithm type.

    meta: dict with 'tuners'/'assessors'/'advisors' lists; each entry has
    'name', 'class_name' and 'installed' fields.
    '''
    MAX_MODULE_NAME = 38
    # Bug fix: the original top border contained a stray '=' and so did not
    # match the other borders; use one shared border string for all three.
    border = '+-----------------+------------+-----------+----------------------+------------------------------------------+'
    print(border)
    print('| Name            | Type       | Installed | Class Name           | Module Name                              |')
    print(border)
    for algo_type in ['tuners', 'assessors', 'advisors']:
        for pkg in meta[algo_type]:
            module_name = '.'.join(pkg['class_name'].split('.')[:-1])
            if len(module_name) > MAX_MODULE_NAME:
                # Truncate long module paths with an ellipsis.
                module_name = module_name[:MAX_MODULE_NAME - 3] + '...'
            class_name = pkg['class_name'].split('.')[-1]
            print('| {:15s} | {:10s} | {:9s} | {:20s} | {:40s} |'.format(
                pkg['name'], algo_type, pkg['installed'], class_name, module_name))
    print(border)
def package_list(args):
    '''list all packages'''
    # With --all, include every builtin algorithm; otherwise installed only.
    meta = get_builtin_algo_meta() if args.all else read_installed_package_meta()
    installed_names = defaultdict(list)
    for algo_type in ['tuners', 'assessors', 'advisors']:
        for pkg in meta[algo_type]:
            pkg['installed'] = 'Yes'
            installed_names[algo_type].append(pkg['name'])
    # Append installable-but-not-installed builtins so users can discover them.
    for name, info in INSTALLABLE_PACKAGE_META.items():
        algo_type = info['type'] + 's'
        if name not in installed_names[algo_type]:
            meta[algo_type].append({
                'name': name,
                'class_name': info['class_name'],
                'class_args_validator': info['class_args_validator'],
                'installed': 'No'
            })
    print_package_list(meta)
def save_package_meta_data(meta_data):
    '''Validate and persist the meta data of a newly installed package.

    meta_data must contain 'type' (one of PACKAGE_TYPES), 'name' and
    'class_name'; 'class_args_validator' and 'package_name' are optional.
    Raises ValueError when the meta data is malformed or the name is
    already registered.
    '''
    # Validate explicitly instead of 'assert', which is stripped under -O.
    if meta_data.get('type') not in PACKAGE_TYPES:
        raise ValueError('Invalid package type: {}'.format(meta_data.get('type')))
    for required_key in ('name', 'class_name'):
        if required_key not in meta_data:
            raise ValueError('Missing {} in package meta data'.format(required_key))
    config = read_installed_package_meta()
    if meta_data['name'] in [x['name'] for x in config[meta_data['type']+'s']]:
        raise ValueError('name %s already installed' % meta_data['name'])
    package_meta = {k: meta_data[k] for k in ['name', 'class_name', 'class_args_validator'] if k in meta_data}
    # Pip-backed packages record the pip name so uninstall can remove them.
    if 'package_name' in meta_data:
        package_meta['installed_package'] = meta_data['package_name']
    config[meta_data['type']+'s'].append(package_meta)
    write_package_meta(config)
def remove_package_meta_data(name):
    '''Remove the meta data record(s) of the named package.

    Returns True when at least one record was removed and persisted,
    False otherwise.
    '''
    config = read_installed_package_meta()
    updated = False
    for algo_type in ALGO_TYPES:
        # Bug fix: rebuild the list instead of calling list.remove() while
        # iterating the same list, which skips the element following each
        # removal (and removed only the first matching dict by equality).
        kept = [meta for meta in config[algo_type] if meta['name'] != name]
        if len(kept) != len(config[algo_type]):
            config[algo_type] = kept
            updated = True
    if updated:
        write_package_meta(config)
        return True
    return False
def get_nni_meta(source):
    '''Extract NNI meta data from a package source (directory or .whl file).

    Returns the meta dict with an added 'package_name' key, or None when
    the source is invalid.
    '''
    if not os.path.exists(source):
        print_error('{} does not exist'.format(source))
        return None
    if os.path.isdir(source):
        if not os.path.exists(os.path.join(source, 'setup.py')):
            print_error('setup.py not found')
            return None
        pkg = pkginfo.Develop(source)
    else:
        if not source.endswith('.whl'):
            print_error('File name {} must ends with \'.whl\''.format(source))
            # Bug fix: return None (not False) so every failure path yields
            # the same type; callers test truthiness, so this is compatible.
            return None
        pkg = pkginfo.Wheel(source)
    meta = parse_classifiers(pkg.classifiers)
    meta['package_name'] = pkg.name
    return meta
def parse_classifiers(classifiers):
    '''Parse the 'NNI Package' classifier string into a package meta dict.

    Expected classifier shape: 'NNI Package :: <type> :: <name> ::
    <class_name> [:: <class_args_validator>]'.  Raises ValueError when no
    well-formed NNI classifier is present.
    '''
    nni_classifier = next((c for c in classifiers if c.startswith('NNI Package')), None)
    parts = [segment.strip() for segment in nni_classifier.split('::')] if nni_classifier else []
    if len(parts) < 4 or not all(parts):
        raise ValueError('Can not find correct NNI meta data in package classifiers.')
    meta = {
        'type': parts[1],
        'name': parts[2],
        'class_name': parts[3]
    }
    # The validator segment is optional.
    if len(parts) >= 5:
        meta['class_args_validator'] = parts[4]
    return meta
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment