Unverified Commit eea50784 authored by chicm-ms, committed by GitHub

Dev pylint (#1697)

Fix pylint errors
parent 1f9b7617
@@ -8,16 +8,33 @@ jobs:
     PYTHON_VERSION: '3.6'
   steps:
-  - script: python3 -m pip install --upgrade pip setuptools --user
+  - script: |
+      python3 -m pip install --upgrade pip setuptools --user
+      python3 -m pip install pylint==2.3.1 astroid==2.2.5 --user
+      python3 -m pip install coverage --user
     displayName: 'Install python tools'
+  - script: |
+      source install.sh
+    displayName: 'Install nni toolkit via source code'
   - script: |
       python3 -m pip install torch==0.4.1 --user
       python3 -m pip install torchvision==0.2.1 --user
       python3 -m pip install tensorflow==1.13.1 --user
+      python3 -m pip install keras==2.1.6 --user
+      python3 -m pip install gym onnx --user
+      sudo apt-get install swig -y
+      PATH=$HOME/.local/bin:$PATH nnictl package install --name=SMAC
+      PATH=$HOME/.local/bin:$PATH nnictl package install --name=BOHB
     displayName: 'Install dependencies'
   - script: |
-      source install.sh
-    displayName: 'Install nni toolkit via source code'
+      set -e
+      python3 -m pylint --rcfile pylintrc nni_annotation
+      python3 -m pylint --rcfile pylintrc nni_cmd
+      python3 -m pylint --rcfile pylintrc nni_gpu_tool
+      python3 -m pylint --rcfile pylintrc nni_trial_tool
+      python3 -m pylint --rcfile pylintrc nni
+      python3 -m pylint --rcfile pylintrc nnicli
+    displayName: 'Run pylint'
   - script: |
       python3 -m pip install flake8 --user
       IGNORE=./tools/nni_annotation/testcase/*:F821,./examples/trials/mnist-nas/*/mnist*.py:F821,./examples/trials/nas_cifar10/src/cifar10/general_child.py:F821
...
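A rough equivalent of the new lint step, for running the same checks locally — a sketch only, assuming it is launched from the repository root where `pylintrc` lives, with the package list copied from the step above:

```python
import subprocess
import sys

# Packages the CI step lints; pylint exits non-zero when it reports messages.
PACKAGES = ['nni_annotation', 'nni_cmd', 'nni_gpu_tool', 'nni_trial_tool', 'nni', 'nnicli']

for package in PACKAGES:
    result = subprocess.run([sys.executable, '-m', 'pylint', '--rcfile', 'pylintrc', package])
    if result.returncode != 0:
        sys.exit(result.returncode)  # mirror `set -e`: stop at the first failing package
```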
 import setuptools
 setuptools.setup(
-    name = 'nnicli',
-    version = '999.0.0-developing',
-    packages = setuptools.find_packages(),
-    python_requires = '>=3.5',
-    install_requires = [
+    name='nnicli',
+    version='999.0.0-developing',
+    packages=setuptools.find_packages(),
+    python_requires='>=3.5',
+    install_requires=[
         'requests'
     ],
-    author = 'Microsoft NNI Team',
-    author_email = 'nni@microsoft.com',
-    description = 'nnicli for Neural Network Intelligence project',
-    license = 'MIT',
-    url = 'https://github.com/Microsoft/nni',
+    author='Microsoft NNI Team',
+    author_email='nni@microsoft.com',
+    description='nnicli for Neural Network Intelligence project',
+    license='MIT',
+    url='https://github.com/Microsoft/nni',
 )
@@ -80,7 +80,7 @@ class Compressor:
         Returns
         -------
         ret : config or None
             the retrieved configuration for this layer, if None, this layer should
             not be compressed
         """
         ret = None
...
@@ -73,7 +73,7 @@ class Compressor:
         Returns
         -------
         ret : config or None
             the retrieved configuration for this layer, if None, this layer should
             not be compressed
         """
         ret = None
...
@@ -143,14 +143,14 @@ class CategoricalPd(Pd):
             re_masked_res = tf.reshape(masked_res, [-1, self.size])
             u = tf.random_uniform(tf.shape(re_masked_res), dtype=self.logits.dtype)
-            return tf.argmax(re_masked_res - tf.log(-tf.log(u)), axis=-1)
+            return tf.argmax(re_masked_res - tf.log(-1*tf.log(u)), axis=-1)
         else:
             u = tf.random_uniform(tf.shape(self.logits), dtype=self.logits.dtype)
-            return tf.argmax(self.logits - tf.log(-tf.log(u)), axis=-1)
+            return tf.argmax(self.logits - tf.log(-1*tf.log(u)), axis=-1)

     @classmethod
     def fromflat(cls, flat):
-        return cls(flat)
+        return cls(flat)  # pylint: disable=no-value-for-parameter

 class CategoricalPdType(PdType):
     """
...
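The sampling lines changed above rely on the Gumbel-max trick: adding independent Gumbel noise `-log(-log(u))`, with `u ~ Uniform(0, 1)`, to the logits and taking the argmax draws an index from `Categorical(softmax(logits))`. A minimal NumPy sketch of the same idea (illustrative only; the tuner itself uses the TensorFlow ops shown in the diff):

```python
import numpy as np

def gumbel_max_sample(logits, rng=np.random):
    """Draw one class index from Categorical(softmax(logits)) via the Gumbel-max trick."""
    u = rng.uniform(size=np.shape(logits))   # u ~ Uniform(0, 1)
    gumbel = -np.log(-np.log(u))             # Gumbel(0, 1) noise
    return int(np.argmax(logits + gumbel))   # same as argmax(logits - log(-log(u)))

logits = np.array([1.0, 2.0, 0.5])
counts = np.bincount([gumbel_max_sample(logits) for _ in range(10000)], minlength=3)
# counts / 10000 should be close to softmax(logits)
```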
@@ -107,7 +107,7 @@ class PolicyWithValue:
         def sample(logits, mask_npinf):
             new_logits = tf.math.add(logits, mask_npinf)
             u = tf.random_uniform(tf.shape(new_logits), dtype=logits.dtype)
-            return tf.argmax(new_logits - tf.log(-tf.log(u)), axis=-1)
+            return tf.argmax(new_logits - tf.log(-1*tf.log(u)), axis=-1)

         def neglogp(logits, x):
             # return tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=x)
...
@@ -22,11 +22,9 @@ ppo_tuner.py including:
 class PPOTuner
 """

-import os
 import copy
 import logging
 import numpy as np
-import json_tricks
 from gym import spaces

 import nni
@@ -236,7 +234,8 @@ class PPOModel:
                 nextnonterminal = 1.0 - trials_info.dones[t+1]
                 nextvalues = trials_info.values[t+1]
             delta = mb_rewards[t] + self.model_config.gamma * nextvalues * nextnonterminal - trials_info.values[t]
-            mb_advs[t] = lastgaelam = delta + self.model_config.gamma * self.model_config.lam * nextnonterminal * lastgaelam
+            lastgaelam = delta + self.model_config.gamma * self.model_config.lam * nextnonterminal * lastgaelam
+            mb_advs[t] = lastgaelam  # pylint: disable=unsupported-assignment-operation
         mb_returns = mb_advs + trials_info.values

         trials_info.update_rewards(mb_rewards, mb_returns)
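The rewritten loop is the backward pass of Generalized Advantage Estimation (GAE): delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_{t+1}) - V(s_t), accumulated as A_t = delta_t + gamma * lambda * (1 - done_{t+1}) * A_{t+1}, with returns = advantages + values. A standalone NumPy sketch of the same recursion (hypothetical helper, for illustration; the tuner operates on its `trials_info` buffers):

```python
import numpy as np

def compute_gae(rewards, values, dones, last_value, gamma=0.99, lam=0.95):
    """Backward GAE recursion: advs[t] = delta_t + gamma * lam * (1 - done) * advs[t + 1]."""
    nsteps = len(rewards)
    advs = np.zeros(nsteps, dtype=np.float64)
    lastgaelam = 0.0
    for t in reversed(range(nsteps)):
        if t == nsteps - 1:
            nextnonterminal, nextvalues = 1.0 - dones[-1], last_value
        else:
            nextnonterminal, nextvalues = 1.0 - dones[t + 1], values[t + 1]
        delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
        lastgaelam = delta + gamma * lam * nextnonterminal * lastgaelam
        advs[t] = lastgaelam
    return advs, advs + np.asarray(values)   # (advantages, returns)
```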
@@ -536,8 +535,10 @@ class PPOTuner(Tuner):
         # generate new trials
         self.trials_result = [None for _ in range(self.inf_batch_size)]
         mb_obs, mb_actions, mb_values, mb_neglogpacs, mb_dones, last_values = self.model.inference(self.inf_batch_size)
-        self.trials_info = TrialsInfo(mb_obs, mb_actions, mb_values, mb_neglogpacs,
-                                      mb_dones, last_values, self.inf_batch_size)
+        self.trials_info = TrialsInfo(mb_obs, mb_actions,
+                                      mb_values, mb_neglogpacs,
+                                      mb_dones, last_values,
+                                      self.inf_batch_size)
         # check credit and submit new trials
         for _ in range(self.credit):
             trial_info_idx, actions = self.trials_info.get_next()
@@ -581,8 +582,8 @@ class PPOTuner(Tuner):
         assert trial_info_idx is not None
         # use mean of finished trials as the result of this failed trial
         values = [val for val in self.trials_result if val is not None]
-        logger.warning('zql values: {0}'.format(values))
-        self.trials_result[trial_info_idx] = (sum(values) / len(values)) if len(values) > 0 else 0
+        logger.warning('zql values: %s', values)
+        self.trials_result[trial_info_idx] = (sum(values) / len(values)) if values else 0
         self.finished_trials += 1
         if self.finished_trials == self.inf_batch_size:
             self._next_round_inference()
...
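The `logger.warning` change above is pylint's logging-format-interpolation fix: passing `%s` arguments lets the logging module defer string formatting until the record is actually emitted, instead of paying for `str.format()` on every call. A small illustration of the two styles (hypothetical logger, for comparison only):

```python
import logging

logger = logging.getLogger(__name__)
values = [0.3, 0.7]

# Flagged by pylint (logging-format-interpolation): formatted even when WARNING is disabled.
logger.warning('values: {0}'.format(values))

# Preferred: arguments are interpolated lazily, only if the record is emitted.
logger.warning('values: %s', values)
```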
@@ -56,7 +56,7 @@ def seq_to_batch(h, flat=False):
 def lstm(xs, ms, s, scope, nh, init_scale=1.0):
     """lstm cell"""
-    nbatch, nin = [v.value for v in xs[0].get_shape()]
+    _, nin = [v.value for v in xs[0].get_shape()]  # the first is nbatch
     with tf.variable_scope(scope):
         wx = tf.get_variable("wx", [nin, nh*4], initializer=ortho_init(init_scale))
         wh = tf.get_variable("wh", [nh, nh*4], initializer=ortho_init(init_scale))
...
-from .smac_tuner import SMACTuner
\ No newline at end of file
+from .smac_tuner import SMACTuner

@@ -39,7 +39,6 @@ from nni.utils import OptimizeMode, extract_scalar_reward
 from .convert_ss_to_scenario import generate_scenario

 class SMACTuner(Tuner):
     """
     Parameters
...
@@ -3,7 +3,7 @@ import sys
 import os
 import signal
 import psutil
-from .common_utils import print_error, print_normal, print_warning
+from .common_utils import print_error

 def check_output_command(file_path, head=None, tail=None):
...
@@ -21,10 +21,10 @@
 import os
 import sys
 import json
-import ruamel.yaml as yaml
-import psutil
 import socket
 from pathlib import Path
+import ruamel.yaml as yaml
+import psutil
 from .constants import ERROR_INFO, NORMAL_INFO, WARNING_INFO, COLOR_RED_FORMAT, COLOR_YELLOW_FORMAT

 def get_yml_content(file_path):
@@ -34,6 +34,7 @@ def get_yml_content(file_path):
             return yaml.load(file, Loader=yaml.Loader)
     except yaml.scanner.ScannerError as err:
         print_error('yaml file format error!')
+        print_error(err)
         exit(1)
     except Exception as exception:
         print_error(exception)
@@ -46,6 +47,7 @@ def get_json_content(file_path):
             return json.load(file)
     except TypeError as err:
         print_error('json file format error!')
+        print_error(err)
         return None

 def print_error(content):
@@ -70,7 +72,7 @@ def detect_process(pid):
 def detect_port(port):
     '''Detect if the port is used'''
-    socket_test = socket.socket(socket.AF_INET,socket.SOCK_STREAM)
+    socket_test = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
     try:
         socket_test.connect(('127.0.0.1', int(port)))
         socket_test.close()
@@ -79,7 +81,7 @@ def detect_port(port):
         return False

 def get_user():
-    if sys.platform =='win32':
+    if sys.platform == 'win32':
         return os.environ['USERNAME']
     else:
         return os.environ['USER']
...
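`detect_port` above decides whether a local port is already taken by simply trying to connect to it. A minimal standalone version of the same check (a sketch with a hypothetical name, mirroring the helper shown in the diff):

```python
import socket

def port_in_use(port, host='127.0.0.1'):
    """Return True if something accepts TCP connections on host:port."""
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.connect((host, int(port)))
        return True
    except ConnectionRefusedError:
        return False
    finally:
        probe.close()

if __name__ == '__main__':
    print(port_in_use(8080))
```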
@@ -19,13 +19,13 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

 import os
-from schema import Schema, And, Use, Optional, Regex, Or
+from schema import Schema, And, Optional, Regex, Or
 from .constants import SCHEMA_TYPE_ERROR, SCHEMA_RANGE_ERROR, SCHEMA_PATH_ERROR

-def setType(key, type):
+def setType(key, valueType):
     '''check key type'''
-    return And(type, error=SCHEMA_TYPE_ERROR % (key, type.__name__))
+    return And(valueType, error=SCHEMA_TYPE_ERROR % (key, valueType.__name__))

 def setChoice(key, *args):
     '''check choice'''
@@ -47,7 +47,7 @@ common_schema = {
     'experimentName': setType('experimentName', str),
     Optional('description'): setType('description', str),
     'trialConcurrency': setNumberRange('trialConcurrency', int, 1, 99999),
-    Optional('maxExecDuration'): And(Regex(r'^[1-9][0-9]*[s|m|h|d]$',error='ERROR: maxExecDuration format is [digit]{s,m,h,d}')),
+    Optional('maxExecDuration'): And(Regex(r'^[1-9][0-9]*[s|m|h|d]$', error='ERROR: maxExecDuration format is [digit]{s,m,h,d}')),
     Optional('maxTrialNum'): setNumberRange('maxTrialNum', int, 1, 99999),
     'trainingServicePlatform': setChoice('trainingServicePlatform', 'remote', 'local', 'pai', 'kubeflow', 'frameworkcontroller'),
     Optional('searchSpacePath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'searchSpacePath'),
@@ -106,7 +106,7 @@ tuner_schema_dict = {
         'builtinTunerName': 'NetworkMorphism',
         Optional('classArgs'): {
             Optional('optimize_mode'): setChoice('optimize_mode', 'maximize', 'minimize'),
-            Optional('task'): setChoice('task', 'cv','nlp','common'),
+            Optional('task'): setChoice('task', 'cv', 'nlp', 'common'),
             Optional('input_width'): setType('input_width', int),
             Optional('input_channel'): setType('input_channel', int),
             Optional('n_output_node'): setType('n_output_node', int),
@@ -139,7 +139,7 @@ tuner_schema_dict = {
             Optional('selection_num_warm_up'): setType('selection_num_warm_up', int),
             Optional('selection_num_starting_points'): setType('selection_num_starting_points', int),
         },
         Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool),
         Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
     },
     'PPOTuner': {
@@ -232,35 +232,35 @@ assessor_schema_dict = {
 }

 common_trial_schema = {
     'trial':{
         'command': setType('command', str),
         'codeDir': setPathCheck('codeDir'),
         Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
         Optional('nasMode'): setChoice('nasMode', 'classic_mode', 'enas_mode', 'oneshot_mode', 'darts_mode')
     }
 }

 pai_trial_schema = {
     'trial':{
         'command': setType('command', str),
         'codeDir': setPathCheck('codeDir'),
         'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
         'cpuNum': setNumberRange('cpuNum', int, 0, 99999),
         'memoryMB': setType('memoryMB', int),
         'image': setType('image', str),
         Optional('authFile'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'authFile'),
         Optional('shmMB'): setType('shmMB', int),
         Optional('dataDir'): And(Regex(r'hdfs://(([0-9]{1,3}.){3}[0-9]{1,3})(:[0-9]{2,5})?(/.*)?'),\
                                  error='ERROR: dataDir format error, dataDir format is hdfs://xxx.xxx.xxx.xxx:xxx'),
         Optional('outputDir'): And(Regex(r'hdfs://(([0-9]{1,3}.){3}[0-9]{1,3})(:[0-9]{2,5})?(/.*)?'),\
                                    error='ERROR: outputDir format error, outputDir format is hdfs://xxx.xxx.xxx.xxx:xxx'),
         Optional('virtualCluster'): setType('virtualCluster', str),
         Optional('nasMode'): setChoice('nasMode', 'classic_mode', 'enas_mode', 'oneshot_mode', 'darts_mode'),
         Optional('portList'): [{
             "label": setType('label', str),
             "beginAt": setType('beginAt', int),
             "portNumber": setType('portNumber', int)
         }]
     }
 }
@@ -273,7 +273,7 @@ pai_config_schema = {
 }

 kubeflow_trial_schema = {
     'trial':{
         'codeDir': setPathCheck('codeDir'),
         Optional('nasMode'): setChoice('nasMode', 'classic_mode', 'enas_mode', 'oneshot_mode', 'darts_mode'),
         Optional('ps'): {
@@ -315,7 +315,7 @@ kubeflow_config_schema = {
             'server': setType('server', str),
             'path': setType('path', str)
         }
-    },{
+    }, {
         'operator': setChoice('operator', 'tf-operator', 'pytorch-operator'),
         'apiVersion': setType('apiVersion', str),
         Optional('storage'): setChoice('storage', 'nfs', 'azureStorage'),
@@ -363,7 +363,7 @@ frameworkcontroller_config_schema = {
             'server': setType('server', str),
             'path': setType('path', str)
         }
-    },{
+    }, {
         Optional('storage'): setChoice('storage', 'nfs', 'azureStorage'),
         Optional('serviceAccountName'): setType('serviceAccountName', str),
         'keyVault': {
@@ -383,24 +383,24 @@ frameworkcontroller_config_schema = {
 }

 machine_list_schema = {
     Optional('machineList'):[Or({
         'ip': setType('ip', str),
         Optional('port'): setNumberRange('port', int, 1, 65535),
         'username': setType('username', str),
         'passwd': setType('passwd', str),
         Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
         Optional('maxTrialNumPerGpu'): setType('maxTrialNumPerGpu', int),
         Optional('useActiveGpu'): setType('useActiveGpu', bool)
-    },{
+    }, {
         'ip': setType('ip', str),
         Optional('port'): setNumberRange('port', int, 1, 65535),
         'username': setType('username', str),
         'sshKeyPath': setPathCheck('sshKeyPath'),
         Optional('passphrase'): setType('passphrase', str),
         Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
         Optional('maxTrialNumPerGpu'): setType('maxTrialNumPerGpu', int),
         Optional('useActiveGpu'): setType('useActiveGpu', bool)
     })]
 }

 LOCAL_CONFIG_SCHEMA = Schema({**common_schema, **common_trial_schema})
...
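The schema dictionaries above are built from the `schema` package's `Schema`, `And`, `Or`, and `Optional` combinators, with helpers such as `setType` and `setChoice` attaching readable error messages. A small hedged example of how such a schema validates a config dict (simplified field set, hypothetical helper name, for illustration only):

```python
from schema import And, Optional, Schema, SchemaError

def set_type(key, value_type):
    """Rough equivalent of the diff's setType helper: a type check with a friendly error."""
    return And(value_type, error='%s should be of type %s!' % (key, value_type.__name__))

toy_schema = Schema({
    'experimentName': set_type('experimentName', str),
    'trialConcurrency': And(int, lambda n: 1 <= n <= 99999,
                            error='trialConcurrency should be in range (1, 99999)!'),
    Optional('description'): set_type('description', str),
})

try:
    toy_schema.validate({'experimentName': 'mnist', 'trialConcurrency': 4})
    print('config is valid')
except SchemaError as err:
    print('invalid config:', err)
```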
@@ -21,7 +21,6 @@
 import os
 import json
-import shutil
 from .constants import NNICTL_HOME_DIR

 class Config:
@@ -73,29 +72,29 @@ class Experiments:
         self.experiment_file = os.path.join(NNICTL_HOME_DIR, '.experiment')
         self.experiments = self.read_file()

-    def add_experiment(self, id, port, time, file_name, platform):
+    def add_experiment(self, expId, port, time, file_name, platform):
         '''set {key:value} paris to self.experiment'''
-        self.experiments[id] = {}
-        self.experiments[id]['port'] = port
-        self.experiments[id]['startTime'] = time
-        self.experiments[id]['endTime'] = 'N/A'
-        self.experiments[id]['status'] = 'INITIALIZED'
-        self.experiments[id]['fileName'] = file_name
-        self.experiments[id]['platform'] = platform
+        self.experiments[expId] = {}
+        self.experiments[expId]['port'] = port
+        self.experiments[expId]['startTime'] = time
+        self.experiments[expId]['endTime'] = 'N/A'
+        self.experiments[expId]['status'] = 'INITIALIZED'
+        self.experiments[expId]['fileName'] = file_name
+        self.experiments[expId]['platform'] = platform
         self.write_file()

-    def update_experiment(self, id, key, value):
+    def update_experiment(self, expId, key, value):
         '''Update experiment'''
-        if id not in self.experiments:
+        if expId not in self.experiments:
             return False
-        self.experiments[id][key] = value
+        self.experiments[expId][key] = value
         self.write_file()
         return True

-    def remove_experiment(self, id):
+    def remove_experiment(self, expId):
         '''remove an experiment by id'''
         if id in self.experiments:
-            self.experiments.pop(id)
+            self.experiments.pop(expId)
             self.write_file()

     def get_all_experiments(self):
@@ -109,7 +108,7 @@ class Experiments:
                 json.dump(self.experiments, file)
         except IOError as error:
             print('Error:', error)
-            return
+            return ''
     def read_file(self):
         '''load config from local file'''
@@ -119,4 +118,4 @@ class Experiments:
                 return json.load(file)
             except ValueError:
                 return {}
         return {}
@@ -21,7 +21,7 @@
 import os
 from colorama import Fore

 NNICTL_HOME_DIR = os.path.join(os.path.expanduser('~'), '.local', 'nnictl')

 ERROR_INFO = 'ERROR: %s'
@@ -58,7 +58,8 @@
                  '-----------------------------------------------------------------------\n'

 EXPERIMENT_START_FAILED_INFO = 'There is an experiment running in the port %d, please stop it first or set another port!\n' \
-                               'You could use \'nnictl stop --port [PORT]\' command to stop an experiment!\nOr you could use \'nnictl create --config [CONFIG_PATH] --port [PORT]\' to set port!\n'
+                               'You could use \'nnictl stop --port [PORT]\' command to stop an experiment!\nOr you could ' \
+                               'use \'nnictl create --config [CONFIG_PATH] --port [PORT]\' to set port!\n'

 EXPERIMENT_INFORMATION_FORMAT = '----------------------------------------------------------------------------------------\n' \
                                 ' Experiment information\n' \
...
@@ -22,22 +22,21 @@
 import json
 import os
 import sys
-import shutil
 import string
-from subprocess import Popen, PIPE, call, check_output, check_call, CalledProcessError
+import random
+import site
+import time
 import tempfile
+from subprocess import Popen, check_call, CalledProcessError
+from nni_annotation import expand_annotations, generate_search_space
 from nni.constants import ModuleName, AdvisorModuleName
-from nni_annotation import *
 from .launcher_utils import validate_all_content
-from .rest_utils import rest_put, rest_post, check_rest_server, check_rest_server_quick, check_response
+from .rest_utils import rest_put, rest_post, check_rest_server, check_response
 from .url_utils import cluster_metadata_url, experiment_url, get_local_urls
 from .config_utils import Config, Experiments
-from .common_utils import get_yml_content, get_json_content, print_error, print_normal, print_warning, detect_process, detect_port, get_user, get_python_dir
-from .constants import *
-import random
-import site
-import time
-from pathlib import Path
+from .common_utils import get_yml_content, get_json_content, print_error, print_normal, \
+    detect_port, get_user, get_python_dir
+from .constants import NNICTL_HOME_DIR, ERROR_INFO, REST_TIME_OUT, EXPERIMENT_SUCCESS_INFO, LOG_HEADER, PACKAGE_REQUIREMENTS
 from .command_utils import check_output_command, kill_command
 from .nnictl_utils import update_experiment
@@ -83,7 +82,8 @@ def get_nni_installation_path():
         python_dir = os.getenv('VIRTUAL_ENV')
     else:
         python_sitepackage = site.getsitepackages()[0]
-        # If system-wide python is used, we will give priority to using `local sitepackage`--"usersitepackages()" given that nni exists there
+        # If system-wide python is used, we will give priority to using `local sitepackage`--"usersitepackages()" given
+        # that nni exists there
         if python_sitepackage.startswith('/usr') or python_sitepackage.startswith('/Library'):
             python_dir = try_installation_path_sequentially(site.getusersitepackages(), site.getsitepackages()[0])
         else:
@@ -98,7 +98,6 @@ def get_nni_installation_path():

 def start_rest_server(port, platform, mode, config_file_name, experiment_id=None, log_dir=None, log_level=None):
     '''Run nni manager process'''
-    nni_config = Config(config_file_name)
     if detect_port(port):
         print_error('Port %s is used by another process, please reset the port!\n' \
                     'You could use \'nnictl create --help\' to get help information' % port)
@@ -114,7 +113,7 @@ def start_rest_server(...)
     entry_dir = get_nni_installation_path()
     entry_file = os.path.join(entry_dir, 'main.js')

     node_command = 'node'
     if sys.platform == 'win32':
         node_command = os.path.join(entry_dir[:-3], 'Scripts', 'node.exe')
@@ -132,7 +131,7 @@ def start_rest_server(...)
         cmds += ['--experiment_id', experiment_id]
     stdout_full_path, stderr_full_path = get_log_path(config_file_name)
     with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file:
-        time_now = time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time()))
+        time_now = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
         #add time information in the header of log files
         log_header = LOG_HEADER % str(time_now)
         stdout_file.write(log_header)
@@ -212,7 +211,7 @@ def setNNIManagerIp(experiment_config, port, config_file_name):
     if experiment_config.get('nniManagerIp') is None:
         return True, None
     ip_config_dict = dict()
-    ip_config_dict['nni_manager_ip'] = { 'nniManagerIp' : experiment_config['nniManagerIp'] }
+    ip_config_dict['nni_manager_ip'] = {'nniManagerIp': experiment_config['nniManagerIp']}
     response = rest_put(cluster_metadata_url(port), json.dumps(ip_config_dict), REST_TIME_OUT)
     err_message = None
     if not response or not response.status_code == 200:
@@ -403,11 +402,12 @@ def launch_experiment(...)
             stdout_full_path, stderr_full_path = get_log_path(config_file_name)
             with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file:
                 check_call([sys.executable, '-c', 'import %s'%(module_name)], stdout=stdout_file, stderr=stderr_file)
-        except CalledProcessError as e:
+        except CalledProcessError:
             print_error('some errors happen when import package %s.' %(package_name))
             print_log_content(config_file_name)
             if package_name in PACKAGE_REQUIREMENTS:
-                print_error('If %s is not installed, it should be installed through \'nnictl package install --name %s\''%(package_name, package_name))
+                print_error('If %s is not installed, it should be installed through '\
+                            '\'nnictl package install --name %s\''%(package_name, package_name))
             exit(1)
     log_dir = experiment_config['logDir'] if experiment_config.get('logDir') else None
     log_level = experiment_config['logLevel'] if experiment_config.get('logLevel') else None
@@ -416,7 +416,8 @@ def launch_experiment(...)
     if log_level not in ['trace', 'debug'] and (args.debug or experiment_config.get('debug') is True):
         log_level = 'debug'
     # start rest server
-    rest_process, start_time = start_rest_server(args.port, experiment_config['trainingServicePlatform'], mode, config_file_name, experiment_id, log_dir, log_level)
+    rest_process, start_time = start_rest_server(args.port, experiment_config['trainingServicePlatform'], \
+                                                 mode, config_file_name, experiment_id, log_dir, log_level)
     nni_config.set_config('restServerPid', rest_process.pid)
     # Deal with annotation
     if experiment_config.get('useAnnotation'):
@@ -450,8 +451,9 @@ def launch_experiment(...)
             exit(1)
     if mode != 'view':
         # set platform configuration
-        set_platform_config(experiment_config['trainingServicePlatform'], experiment_config, args.port, config_file_name, rest_process)
+        set_platform_config(experiment_config['trainingServicePlatform'], experiment_config, args.port,\
+                            config_file_name, rest_process)
     # start a new experiment
     print_normal('Starting experiment...')
     # set debug configuration
@@ -478,7 +480,8 @@ def launch_experiment(...)
     #save experiment information
     nnictl_experiment_config = Experiments()
-    nnictl_experiment_config.add_experiment(experiment_id, args.port, start_time, config_file_name, experiment_config['trainingServicePlatform'])
+    nnictl_experiment_config.add_experiment(experiment_id, args.port, start_time, config_file_name,\
+                                            experiment_config['trainingServicePlatform'])

     print_normal(EXPERIMENT_SUCCESS_INFO % (experiment_id, ' '.join(web_ui_url_list)))
@@ -503,7 +506,6 @@ def manage_stopped_experiment(args, mode):
     experiment_config = Experiments()
     experiment_dict = experiment_config.get_all_experiments()
     experiment_id = None
-    experiment_endTime = None
     #find the latest stopped experiment
     if not args.id:
         print_error('Please set experiment id! \nYou could use \'nnictl {0} {id}\' to {0} a stopped experiment!\n' \
...
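launch_experiment verifies that a tuner's package can actually be imported by running `python -c 'import <module>'` in a child process and treating `CalledProcessError` as a missing dependency. A hedged standalone sketch of that pattern (hypothetical function name, illustrative only):

```python
import sys
from subprocess import DEVNULL, CalledProcessError, check_call

def module_importable(module_name):
    """Return True if `python -c "import <module_name>"` succeeds in a child process."""
    try:
        check_call([sys.executable, '-c', 'import %s' % module_name],
                   stdout=DEVNULL, stderr=DEVNULL)
        return True
    except CalledProcessError:
        return False

print(module_importable('json'))            # True
print(module_importable('no_such_module'))  # False
```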
@@ -20,11 +20,11 @@
 import os
 import json
-from .config_schema import LOCAL_CONFIG_SCHEMA, REMOTE_CONFIG_SCHEMA, PAI_CONFIG_SCHEMA, KUBEFLOW_CONFIG_SCHEMA, FRAMEWORKCONTROLLER_CONFIG_SCHEMA, \
-    tuner_schema_dict, advisor_schema_dict, assessor_schema_dict
-from schema import SchemaMissingKeyError, SchemaForbiddenKeyError, SchemaUnexpectedTypeError, SchemaWrongKeyError, SchemaError
-from .common_utils import get_json_content, print_error, print_warning, print_normal
-from schema import Schema, And, Use, Optional, Regex, Or
+from schema import SchemaError
+from schema import Schema
+from .config_schema import LOCAL_CONFIG_SCHEMA, REMOTE_CONFIG_SCHEMA, PAI_CONFIG_SCHEMA, KUBEFLOW_CONFIG_SCHEMA,\
+                           FRAMEWORKCONTROLLER_CONFIG_SCHEMA, tuner_schema_dict, advisor_schema_dict, assessor_schema_dict
+from .common_utils import print_error, print_warning, print_normal

 def expand_path(experiment_config, key):
     '''Change '~' to user home directory'''
@@ -164,11 +164,11 @@ def validate_common_content(experiment_config):
         print_error('Please set correct trainingServicePlatform!')
         exit(1)
     schema_dict = {
         'local': LOCAL_CONFIG_SCHEMA,
         'remote': REMOTE_CONFIG_SCHEMA,
         'pai': PAI_CONFIG_SCHEMA,
         'kubeflow': KUBEFLOW_CONFIG_SCHEMA,
         'frameworkcontroller': FRAMEWORKCONTROLLER_CONFIG_SCHEMA
     }
     separate_schema_dict = {
         'tuner': tuner_schema_dict,
...
@@ -20,14 +20,18 @@
 import argparse
+import os
 import pkg_resources
+from colorama import init
+from .common_utils import print_error
 from .launcher import create_experiment, resume_experiment, view_experiment
 from .updater import update_searchspace, update_concurrency, update_duration, update_trialnum, import_data
-from .nnictl_utils import *
-from .package_management import *
-from .constants import *
-from .tensorboard_utils import *
-from colorama import init
+from .nnictl_utils import stop_experiment, trial_ls, trial_kill, list_experiment, experiment_status,\
+                          log_trial, experiment_clean, platform_clean, experiment_list, \
+                          monitor_experiment, export_trials_data, trial_codegen, webui_url, get_config, log_stdout, log_stderr
+from .package_management import package_install, package_show
+from .constants import DEFAULT_REST_PORT
+from .tensorboard_utils import start_tensorboard, stop_tensorboard

 init(autoreset=True)

 if os.environ.get('COVERAGE_PROCESS_START'):
@@ -38,7 +42,7 @@ def nni_info(*args):
     if args[0].version:
         try:
             print(pkg_resources.get_distribution('nni').version)
-        except pkg_resources.ResolutionError as err:
+        except pkg_resources.ResolutionError:
             print_error('Get version failed, please use `pip3 list | grep nni` to check nni version!')
     else:
         print('please run "nnictl {positional argument} --help" to see nnictl guidance')
...
@@ -20,15 +20,13 @@
 import csv
 import os
-import psutil
 import json
-from datetime import datetime, timezone
 import time
 import re
-from pathlib import Path
-from pyhdfs import HdfsClient, HdfsFileNotFoundException
 import shutil
-from subprocess import call, check_output
+from datetime import datetime, timezone
+from pathlib import Path
+from pyhdfs import HdfsClient
 from nni_annotation import expand_annotations
 from .rest_utils import rest_get, rest_delete, check_rest_server_quick, check_response
 from .url_utils import trial_jobs_url, experiment_url, trial_job_id_url, export_data_url
@@ -102,7 +100,8 @@ def check_experiment_id(args, update=True):
         experiment_information = ""
         for key in running_experiment_list:
             experiment_information += (EXPERIMENT_DETAIL_FORMAT % (key, experiment_dict[key]['status'], \
-            experiment_dict[key]['port'], experiment_dict[key].get('platform'), experiment_dict[key]['startTime'], experiment_dict[key]['endTime']))
+            experiment_dict[key]['port'], experiment_dict[key].get('platform'), experiment_dict[key]['startTime'],\
+            experiment_dict[key]['endTime']))
         print(EXPERIMENT_INFORMATION_FORMAT % experiment_information)
         exit(1)
     elif not running_experiment_list:
@@ -157,23 +156,24 @@ def parse_ids(args):
             experiment_information = ""
             for key in running_experiment_list:
                 experiment_information += (EXPERIMENT_DETAIL_FORMAT % (key, experiment_dict[key]['status'], \
-                experiment_dict[key]['port'], experiment_dict[key].get('platform'), experiment_dict[key]['startTime'], experiment_dict[key]['endTime']))
+                experiment_dict[key]['port'], experiment_dict[key].get('platform'), experiment_dict[key]['startTime'], \
+                experiment_dict[key]['endTime']))
             print(EXPERIMENT_INFORMATION_FORMAT % experiment_information)
             exit(1)
         else:
             result_list = running_experiment_list
     elif args.id.endswith('*'):
-        for id in running_experiment_list:
-            if id.startswith(args.id[:-1]):
-                result_list.append(id)
+        for expId in running_experiment_list:
+            if expId.startswith(args.id[:-1]):
+                result_list.append(expId)
     elif args.id in running_experiment_list:
         result_list.append(args.id)
     else:
-        for id in running_experiment_list:
-            if id.startswith(args.id):
-                result_list.append(id)
+        for expId in running_experiment_list:
+            if expId.startswith(args.id):
+                result_list.append(expId)
     if len(result_list) > 1:
-        print_error(args.id + ' is ambiguous, please choose ' + ' '.join(result_list) )
+        print_error(args.id + ' is ambiguous, please choose ' + ' '.join(result_list))
         return None
     if not result_list and (args.id or args.port):
         print_error('There are no experiments matched, please set correct experiment id or restful server port')
@@ -235,7 +235,6 @@ def stop_experiment(args):
         for experiment_id in experiment_id_list:
             print_normal('Stoping experiment %s' % experiment_id)
             nni_config = Config(experiment_dict[experiment_id]['fileName'])
-            rest_port = nni_config.get_config('restServerPort')
             rest_pid = nni_config.get_config('restServerPid')
             if rest_pid:
                 kill_command(rest_pid)
@@ -249,7 +248,7 @@ def stop_experiment(args):
             nni_config.set_config('tensorboardPidList', [])
             print_normal('Stop experiment success.')
             experiment_config.update_experiment(experiment_id, 'status', 'STOPPED')
-            time_now = time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time()))
+            time_now = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
             experiment_config.update_experiment(experiment_id, 'endTime', str(time_now))

 def trial_ls(args):
@@ -401,9 +400,9 @@ def local_clean(directory):
     print_normal('removing folder {0}'.format(directory))
     try:
         shutil.rmtree(directory)
-    except FileNotFoundError as err:
+    except FileNotFoundError:
         print_error('{0} does not exist.'.format(directory))

 def remote_clean(machine_list, experiment_id=None):
     '''clean up remote data'''
     for machine in machine_list:
@@ -418,7 +417,7 @@ def remote_clean(machine_list, experiment_id=None):
         sftp = create_ssh_sftp_client(host, port, userName, passwd)
         print_normal('removing folder {0}'.format(host + ':' + str(port) + remote_dir))
         remove_remote_directory(sftp, remote_dir)

 def hdfs_clean(host, user_name, output_dir, experiment_id=None):
     '''clean up hdfs data'''
     hdfs_client = HdfsClient(hosts='{0}:80'.format(host), user_name=user_name, webhdfs_path='/webhdfs/api/v1', timeout=5)
@@ -475,7 +474,7 @@ def experiment_clean(args):
             machine_list = nni_config.get_config('experimentConfig').get('machineList')
             remote_clean(machine_list, experiment_id)
         elif platform == 'pai':
             host = nni_config.get_config('experimentConfig').get('paiConfig').get('host')
             user_name = nni_config.get_config('experimentConfig').get('paiConfig').get('userName')
             output_dir = nni_config.get_config('experimentConfig').get('trial').get('outputDir')
             hdfs_clean(host, user_name, output_dir, experiment_id)
@@ -492,7 +491,7 @@ def experiment_clean(args):
         experiment_config = Experiments()
         print_normal('removing metadata of experiment {0}'.format(experiment_id))
         experiment_config.remove_experiment(experiment_id)
     print_normal('Done.')

 def get_platform_dir(config_content):
     '''get the dir list to be deleted'''
@@ -505,8 +504,7 @@ def get_platform_dir(config_content):
             port = machine.get('port')
             dir_list.append(host + ':' + str(port) + '/tmp/nni')
     elif platform == 'pai':
-        pai_config = config_content.get('paiConfig')
         host = config_content.get('paiConfig').get('host')
         user_name = config_content.get('paiConfig').get('userName')
         output_dir = config_content.get('trial').get('outputDir')
         dir_list.append('server: {0}, path: {1}/nni'.format(host, user_name))
@@ -529,17 +527,15 @@ def platform_clean(args):
         print_normal('platform {0} not supported.'.format(platform))
         exit(0)
     update_experiment()
-    experiment_config = Experiments()
-    experiment_dict = experiment_config.get_all_experiments()
-    id_list = list(experiment_dict.keys())
     dir_list = get_platform_dir(config_content)
     if not dir_list:
         print_normal('No folder of NNI caches is found.')
         exit(1)
     while True:
-        print_normal('This command will remove below folders of NNI caches. If other users are using experiments on below hosts, it will be broken.')
-        for dir in dir_list:
-            print(' ' + dir)
+        print_normal('This command will remove below folders of NNI caches. If other users are using experiments' \
+                     ' on below hosts, it will be broken.')
+        for value in dir_list:
+            print(' ' + value)
         inputs = input('INFO: do you want to continue?[y/N]:')
         if not inputs.lower() or inputs.lower() in ['n', 'no']:
             exit(0)
@@ -549,11 +545,9 @@ def platform_clean(args):
             break
     if platform == 'remote':
         machine_list = config_content.get('machineList')
-        for machine in machine_list:
-            remote_clean(machine_list, None)
+        remote_clean(machine_list, None)
     elif platform == 'pai':
-        pai_config = config_content.get('paiConfig')
         host = config_content.get('paiConfig').get('host')
         user_name = config_content.get('paiConfig').get('userName')
         output_dir = config_content.get('trial').get('outputDir')
         hdfs_clean(host, user_name, output_dir, None)
@@ -618,7 +612,8 @@ def show_experiment_info():
         return
     for key in experiment_id_list:
         print(EXPERIMENT_MONITOR_INFO % (key, experiment_dict[key]['status'], experiment_dict[key]['port'], \
-                experiment_dict[key].get('platform'), experiment_dict[key]['startTime'], get_time_interval(experiment_dict[key]['startTime'], experiment_dict[key]['endTime'])))
+                experiment_dict[key].get('platform'), experiment_dict[key]['startTime'], \
+                get_time_interval(experiment_dict[key]['startTime'], experiment_dict[key]['endTime'])))
         print(TRIAL_MONITOR_HEAD)
         running, response = check_rest_server_quick(experiment_dict[key]['port'])
         if running:
@@ -627,7 +622,8 @@ def show_experiment_info():
             content = json.loads(response.text)
             for index, value in enumerate(content):
                 content[index] = convert_time_stamp_to_date(value)
-                print(TRIAL_MONITOR_CONTENT % (content[index].get('id'), content[index].get('startTime'), content[index].get('endTime'), content[index].get('status')))
+                print(TRIAL_MONITOR_CONTENT % (content[index].get('id'), content[index].get('startTime'), \
+                      content[index].get('endTime'), content[index].get('status')))
             print(TRIAL_MONITOR_TAIL)

 def monitor_experiment(args):
...
@@ -18,12 +18,10 @@
 # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

-import nni
 import os
-import sys
-from subprocess import call
+import nni
 from .constants import PACKAGE_REQUIREMENTS
-from .common_utils import print_normal, print_error
+from .common_utils import print_error
 from .command_utils import install_requirements_command

 def process_install(package_name):
...