Commit b210695f authored by chicm-ms, committed by GitHub

Merge branch 'master' into dev-pruner-dataparallel

parents c7d58033 fdfff50d
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch.nn as nn
import torch.nn.functional as F

from nni.nas.pytorch.mutables import LayerChoice, InputChoice


class MutableOp(nn.Module):
    def __init__(self, kernel_size):
        super().__init__()
        self.conv = nn.Conv2d(3, 120, kernel_size, padding=kernel_size // 2)
        # A mutable nested inside a LayerChoice candidate, which is exactly
        # what makes NestedSpace an unsupported search space.
        self.nested_mutable = InputChoice(n_candidates=10)

    def forward(self, x):
        return self.conv(x)


class NestedSpace(nn.Module):
    # Intentionally unsupported: the LayerChoice candidates contain their own
    # mutables, so applying a mutator to this space is expected to raise
    # RuntimeError (see test_nested_space in the test file below).
    def __init__(self, test_case):
        super().__init__()
        self.test_case = test_case
        self.conv1 = LayerChoice([MutableOp(3), MutableOp(5)])
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(120, 10)

    def forward(self, x):
        bs = x.size(0)
        x = F.relu(self.conv1(x))
        x = self.gap(x).view(bs, -1)
        x = self.fc1(x)
        return x
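For context, a minimal sketch (not part of the commit) of how this intentionally unsupported space behaves; it assumes the NNI NAS API imported above, and passes None for the unused test_case argument:

    from nni.nas.pytorch.random import RandomMutator

    model = NestedSpace(test_case=None)
    try:
        RandomMutator(model)  # expected to raise: the space nests a mutable
    except RuntimeError as err:
        print('nested space rejected:', err)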
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import importlib
import os
import sys
from unittest import TestCase, main

import torch
import torch.nn as nn

from nni.nas.pytorch.classic_nas import get_and_apply_next_architecture
from nni.nas.pytorch.darts import DartsMutator
from nni.nas.pytorch.enas import EnasMutator
from nni.nas.pytorch.fixed import apply_fixed_architecture
from nni.nas.pytorch.random import RandomMutator
from nni.nas.pytorch.utils import _reset_global_mutable_counting


class NasTestCase(TestCase):
    def setUp(self):
        self.default_input_size = [3, 32, 32]
        self.model_path = os.path.join(os.path.dirname(__file__), "models")
        sys.path.append(self.model_path)
        self.model_module = importlib.import_module("pytorch_models")
        self.default_cls = [self.model_module.NaiveSearchSpace, self.model_module.SpaceWithMutableScope]
        self.cuda_test = [0]
        if torch.cuda.is_available():
            self.cuda_test.append(1)
            if torch.cuda.device_count() > 1:
                self.cuda_test.append(torch.cuda.device_count())

    def tearDown(self):
        sys.path.remove(self.model_path)

    def iterative_sample_and_forward(self, model, mutator=None, input_size=None, n_iters=20, test_backward=True,
                                     use_cuda=False):
        if input_size is None:
            input_size = self.default_input_size
        # support pytorch only
        input_size = [8 if use_cuda else 2] + input_size  # at least 2 samples to enable batch norm
        for _ in range(n_iters):
            for param in model.parameters():
                param.grad = None
            if mutator is not None:
                mutator.reset()
            x = torch.randn(input_size)
            if use_cuda:
                x = x.cuda()
            y = torch.sum(model(x))
            if test_backward:
                y.backward()

    def default_mutator_test_pipeline(self, mutator_cls):
        for model_cls in self.default_cls:
            for cuda_test in self.cuda_test:
                _reset_global_mutable_counting()
                model = model_cls(self)
                mutator = mutator_cls(model)
                if cuda_test:
                    model.cuda()
                    mutator.cuda()
                if cuda_test > 1:
                    model = nn.DataParallel(model)
                self.iterative_sample_and_forward(model, mutator, use_cuda=cuda_test)
                _reset_global_mutable_counting()
                model_fixed = model_cls(self)
                if cuda_test:
                    model_fixed.cuda()
                if cuda_test > 1:
                    model_fixed = nn.DataParallel(model_fixed)
                with torch.no_grad():
                    arc = mutator.export()
                apply_fixed_architecture(model_fixed, arc)
                self.iterative_sample_and_forward(model_fixed, n_iters=1, use_cuda=cuda_test)

    def test_random_mutator(self):
        self.default_mutator_test_pipeline(RandomMutator)

    def test_enas_mutator(self):
        self.default_mutator_test_pipeline(EnasMutator)

    def test_darts_mutator(self):
        # DARTS doesn't support DataParallel. To be fixed.
        self.cuda_test = [t for t in self.cuda_test if t <= 1]
        self.default_mutator_test_pipeline(DartsMutator)

    def test_apply_twice(self):
        model = self.model_module.NaiveSearchSpace(self)
        with self.assertRaises(RuntimeError):
            for _ in range(2):
                RandomMutator(model)

    def test_nested_space(self):
        model = self.model_module.NestedSpace(self)
        with self.assertRaises(RuntimeError):
            RandomMutator(model)

    def test_classic_nas(self):
        for model_cls in self.default_cls:
            model = model_cls(self)
            get_and_apply_next_architecture(model)
            self.iterative_sample_and_forward(model)


if __name__ == '__main__':
    main()
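For reference, the mutator lifecycle that default_mutator_test_pipeline exercises, condensed into a sketch (model and model_fixed stand for two instances of one search-space class, x for an input batch; same imports as above):

    mutator = RandomMutator(model)               # bind the mutator to the search space
    mutator.reset()                              # sample a fresh architecture
    loss = torch.sum(model(x))                   # forward pass follows the sampled choices
    loss.backward()
    arc = mutator.export()                       # freeze the sampled decisions
    apply_fixed_architecture(model_fixed, arc)   # replay them on a clean instance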
@@ -29,6 +29,12 @@ def gen_new_config(config_file, training_service='local'):
         config['trial'].pop('command')
         if 'gpuNum' in config['trial']:
             config['trial'].pop('gpuNum')
+    if training_service == 'frameworkcontroller':
+        it_config[training_service]['trial']['taskRoles'][0]['command'] = config['trial']['command']
+        config['trial'].pop('command')
+        if 'gpuNum' in config['trial']:
+            config['trial'].pop('gpuNum')

     deep_update(config, it_config['all'])
     deep_update(config, it_config[training_service])
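deep_update is defined elsewhere in this file; as a hedged illustration, a minimal stand-in with the presumed recursive-merge semantics (override values win on conflicts), not the repository's exact implementation:

    def deep_update_sketch(base, overrides):
        # Recursively merge `overrides` into `base`; leaf values in
        # `overrides` replace those in `base`.
        for key, value in overrides.items():
            if isinstance(value, dict) and isinstance(base.get(key), dict):
                deep_update_sketch(base[key], value)
            else:
                base[key] = value

    cfg = {'trial': {'gpuNum': 1, 'command': 'python3 mnist.py'}}
    deep_update_sketch(cfg, {'trial': {'gpuNum': 0}})
    assert cfg == {'trial': {'gpuNum': 0, 'command': 'python3 mnist.py'}}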
@@ -106,7 +112,7 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument("--config", type=str, default=None)
     parser.add_argument("--exclude", type=str, default=None)
-    parser.add_argument("--ts", type=str, choices=['local', 'remote', 'pai', 'kubeflow'], default='local')
+    parser.add_argument("--ts", type=str, choices=['local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller'], default='local')
     parser.add_argument("--local_gpu", action='store_true')
     parser.add_argument("--preinstall", action='store_true')
     args = parser.parse_args()
......
@@ -42,6 +42,21 @@ def update_training_service_config(args):
             config[args.ts]['kubeflowConfig']['azureStorage']['azureShare'] = args.azs_share
         if args.nni_docker_image is not None:
            config[args.ts]['trial']['worker']['image'] = args.nni_docker_image
+    elif args.ts == 'frameworkcontroller':
+        if args.nfs_server is not None:
+            config[args.ts]['frameworkcontrollerConfig']['nfs']['server'] = args.nfs_server
+        if args.nfs_path is not None:
+            config[args.ts]['frameworkcontrollerConfig']['nfs']['path'] = args.nfs_path
+        if args.keyvault_vaultname is not None:
+            config[args.ts]['frameworkcontrollerConfig']['keyVault']['vaultName'] = args.keyvault_vaultname
+        if args.keyvault_name is not None:
+            config[args.ts]['frameworkcontrollerConfig']['keyVault']['name'] = args.keyvault_name
+        if args.azs_account is not None:
+            config[args.ts]['frameworkcontrollerConfig']['azureStorage']['accountName'] = args.azs_account
+        if args.azs_share is not None:
+            config[args.ts]['frameworkcontrollerConfig']['azureStorage']['azureShare'] = args.azs_share
+        if args.nni_docker_image is not None:
+            config[args.ts]['trial']['taskRoles'][0]['image'] = args.nni_docker_image
     elif args.ts == 'remote':
         if args.remote_user is not None:
             config[args.ts]['machineList'][0]['username'] = args.remote_user
@@ -69,7 +84,7 @@ def convert_command():
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument("--ts", type=str, choices=['pai', 'kubeflow', 'remote', 'local'], default='pai')
+    parser.add_argument("--ts", type=str, choices=['pai', 'kubeflow', 'remote', 'local', 'frameworkcontroller'], default='pai')
     parser.add_argument("--nni_docker_image", type=str)
     parser.add_argument("--nni_manager_ip", type=str)
     # args for PAI
@@ -79,7 +94,7 @@ if __name__ == '__main__':
     parser.add_argument("--data_dir", type=str)
     parser.add_argument("--output_dir", type=str)
     parser.add_argument("--vc", type=str)
-    # args for kubeflow
+    # args for kubeflow and frameworkController
     parser.add_argument("--nfs_server", type=str)
     parser.add_argument("--nfs_path", type=str)
    parser.add_argument("--keyvault_vaultname", type=str)
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

jobs:
- job: 'integration_test_frameworkController'
  timeoutInMinutes: 0

  steps:
  - script: python3 -m pip install --upgrade pip setuptools --user
    displayName: 'Install python tools'
  - script: |
      cd deployment/pypi
      echo 'building prerelease package...'
      make build
      ls $(Build.SourcesDirectory)/deployment/pypi/dist/
    condition: eq( variables['build_docker_img'], 'true' )
    displayName: 'build nni bdist_wheel'
  - script: |
      source install.sh
    displayName: 'Install nni toolkit via source code'
  - script: |
      sudo apt-get install swig -y
      PATH=$HOME/.local/bin:$PATH nnictl package install --name=SMAC
      PATH=$HOME/.local/bin:$PATH nnictl package install --name=BOHB
    displayName: 'Install dependencies for integration tests in frameworkcontroller mode'
  - script: |
      if [ $(build_docker_img) = 'true' ]
      then
        cd deployment/pypi
        docker login -u $(docker_hub_user) -p $(docker_hub_pwd)
        echo 'updating docker file for installing nni from local...'
        # update Dockerfile to install NNI in the docker image from the whl file built in the last step
        sed -ie 's/RUN python3 -m pip --no-cache-dir install nni/COPY .\/dist\/* .\nRUN python3 -m pip install nni-*.whl/' ../docker/Dockerfile
        cat ../docker/Dockerfile
        export IMG_TAG=`date -u +%y%m%d%H%M`
        docker build -f ../docker/Dockerfile -t $(test_docker_img_name):$IMG_TAG .
        docker push $(test_docker_img_name):$IMG_TAG
        export TEST_IMG=$(test_docker_img_name):$IMG_TAG
        cd ../../
      else
        export TEST_IMG=$(existing_docker_img)
      fi
      echo "TEST_IMG:$TEST_IMG"
      cd test
      python3 generate_ts_config.py --ts frameworkcontroller --keyvault_vaultname $(keyVault_vaultName) --keyvault_name $(keyVault_name) \
        --azs_account $(azureStorage_accountName) --azs_share $(azureStorage_azureShare) --nni_docker_image $TEST_IMG --nni_manager_ip $(nni_manager_ip)
      cat training_service.yml
      PATH=$HOME/.local/bin:$PATH python3 config_test.py --ts frameworkcontroller --exclude multi_phase
    displayName: 'integration test'
@@ -8,7 +8,7 @@ jobs:
   - script: |
       python -m pip install scikit-learn==0.20.0 --user
       python -m pip install keras==2.1.6 --user
-      python -m pip install https://download.pytorch.org/whl/cu90/torch-0.4.1-cp36-cp36m-win_amd64.whl --user
+      python -m pip install torch===1.2.0 torchvision===0.4.1 -f https://download.pytorch.org/whl/torch_stable.html --user
       python -m pip install torchvision --user
       python -m pip install tensorflow-gpu==1.11.0 --user
     displayName: 'Install dependencies for integration tests'
......
@@ -24,6 +24,32 @@ kubeflow:
       image:
   trainingServicePlatform: kubeflow
+frameworkcontroller:
+  maxExecDuration: 15m
+  nniManagerIp:
+  frameworkcontrollerConfig:
+    serviceAccountName: frameworkbarrier
+    storage: azureStorage
+    keyVault:
+      vaultName:
+      name:
+    azureStorage:
+      accountName:
+      azureShare:
+  trial:
+    taskRoles:
+      - name: worker
+        taskNum: 1
+        command:
+        gpuNum: 1
+        cpuNum: 1
+        memoryMB: 8192
+        image:
+        frameworkAttemptCompletionPolicy:
+          minFailedTaskCount: 1
+          minSucceededTaskCount: 1
+  trainingServicePlatform: frameworkcontroller
 local:
   trainingServicePlatform: local
 pai:
......
@@ -271,16 +271,17 @@ pai_yarn_config_schema = {
 pai_trial_schema = {
     'trial':{
-        'command': setType('command', str),
         'codeDir': setPathCheck('codeDir'),
-        'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
-        'cpuNum': setNumberRange('cpuNum', int, 0, 99999),
-        'memoryMB': setType('memoryMB', int),
-        'image': setType('image', str),
-        Optional('virtualCluster'): setType('virtualCluster', str),
         'nniManagerNFSMountPath': setPathCheck('nniManagerNFSMountPath'),
         'containerNFSMountPath': setType('containerNFSMountPath', str),
-        'paiStoragePlugin': setType('paiStoragePlugin', str)
+        'command': setType('command', str),
+        Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
+        Optional('cpuNum'): setNumberRange('cpuNum', int, 0, 99999),
+        Optional('memoryMB'): setType('memoryMB', int),
+        Optional('image'): setType('image', str),
+        Optional('virtualCluster'): setType('virtualCluster', str),
+        Optional('paiStoragePlugin'): setType('paiStoragePlugin', str),
+        Optional('paiConfigPath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'paiConfigPath')
     }
 }
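For illustration, a hedged sketch of how such a schema dict is consumed, assuming the pai_trial_schema defined above is in scope (Schema and SchemaError come from the third-party schema package already used by nnictl); an incomplete config should fail validation:

    from schema import Schema, SchemaError

    try:
        # Deliberately incomplete: required keys such as 'command' are absent.
        Schema(pai_trial_schema).validate({'trial': {'codeDir': '.'}})
    except SchemaError as err:
        print('invalid trial config:', err)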
@@ -407,15 +408,8 @@ frameworkcontroller_config_schema = {
 }
 machine_list_schema = {
-    Optional('machineList'):[Or({
-        'ip': setType('ip', str),
-        Optional('port'): setNumberRange('port', int, 1, 65535),
-        'username': setType('username', str),
-        'passwd': setType('passwd', str),
-        Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
-        Optional('maxTrialNumPerGpu'): setType('maxTrialNumPerGpu', int),
-        Optional('useActiveGpu'): setType('useActiveGpu', bool)
-    }, {
+    Optional('machineList'):[Or(
+        {
         'ip': setType('ip', str),
         Optional('port'): setNumberRange('port', int, 1, 65535),
         'username': setType('username', str),
@@ -424,6 +418,15 @@ machine_list_schema = {
         Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
         Optional('maxTrialNumPerGpu'): setType('maxTrialNumPerGpu', int),
         Optional('useActiveGpu'): setType('useActiveGpu', bool)
+        },
+        {
+        'ip': setType('ip', str),
+        Optional('port'): setNumberRange('port', int, 1, 65535),
+        'username': setType('username', str),
+        'passwd': setType('passwd', str),
+        Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
+        Optional('maxTrialNumPerGpu'): setType('maxTrialNumPerGpu', int),
+        Optional('useActiveGpu'): setType('useActiveGpu', bool)
         })]
 }
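The Or(...) above admits two entry shapes for machineList: key-file authentication (the sshKeyPath/passphrase variant, whose body is elided between the two hunks) and password authentication. A hedged sketch of the two shapes, with illustrative values, consistent with the ssh changes later in this commit:

    keyfile_machine = {
        'ip': '10.0.0.2',
        'username': 'nni',
        'sshKeyPath': '/home/nni/.ssh/id_rsa',  # key-based variant
    }
    password_machine = {
        'ip': '10.0.0.1',
        'username': 'nni',
        'passwd': 'secret',                     # password-based variant
    }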
......
@@ -9,7 +9,7 @@ import random
 import site
 import time
 import tempfile
-from subprocess import Popen, check_call, CalledProcessError
+from subprocess import Popen, check_call, CalledProcessError, PIPE, STDOUT
 from nni_annotation import expand_annotations, generate_search_space
 from nni.constants import ModuleName, AdvisorModuleName
 from .launcher_utils import validate_all_content
@@ -20,7 +20,7 @@ from .common_utils import get_yml_content, get_json_content, print_error, print_
     detect_port, get_user, get_python_dir
 from .constants import NNICTL_HOME_DIR, ERROR_INFO, REST_TIME_OUT, EXPERIMENT_SUCCESS_INFO, LOG_HEADER, PACKAGE_REQUIREMENTS
 from .command_utils import check_output_command, kill_command
-from .nnictl_utils import update_experiment, set_monitor
+from .nnictl_utils import update_experiment

 def get_log_path(config_file_name):
     '''generate stdout and stderr log path'''
@@ -78,17 +78,17 @@ def get_nni_installation_path():
     print_error('Fail to find nni under python library')
     exit(1)

-def start_rest_server(port, platform, mode, config_file_name, experiment_id=None, log_dir=None, log_level=None):
+def start_rest_server(args, platform, mode, config_file_name, experiment_id=None, log_dir=None, log_level=None):
     '''Run nni manager process'''
-    if detect_port(port):
+    if detect_port(args.port):
         print_error('Port %s is used by another process, please reset the port!\n' \
-                    'You could use \'nnictl create --help\' to get help information' % port)
+                    'You could use \'nnictl create --help\' to get help information' % args.port)
         exit(1)
-    if (platform != 'local') and detect_port(int(port) + 1):
+    if (platform != 'local') and detect_port(int(args.port) + 1):
         print_error('PAI mode need an additional adjacent port %d, and the port %d is used by another process!\n' \
                     'You could set another port to start experiment!\n' \
-                    'You could use \'nnictl create --help\' to get help information' % ((int(port) + 1), (int(port) + 1)))
+                    'You could use \'nnictl create --help\' to get help information' % ((int(args.port) + 1), (int(args.port) + 1)))
         exit(1)

     print_normal('Starting restful server...')
@@ -99,7 +99,7 @@ def start_rest_server(port, platform, mode, config_file_name, experiment_id=None
     node_command = 'node'
     if sys.platform == 'win32':
         node_command = os.path.join(entry_dir[:-3], 'Scripts', 'node.exe')
-    cmds = [node_command, entry_file, '--port', str(port), '--mode', platform]
+    cmds = [node_command, entry_file, '--port', str(args.port), '--mode', platform]
     if mode == 'view':
         cmds += ['--start_mode', 'resume']
         cmds += ['--readonly', 'true']
@@ -111,6 +111,8 @@ def start_rest_server(port, platform, mode, config_file_name, experiment_id=None
         cmds += ['--log_level', log_level]
     if mode in ['resume', 'view']:
         cmds += ['--experiment_id', experiment_id]
+    if args.foreground:
+        cmds += ['--foreground', 'true']
     stdout_full_path, stderr_full_path = get_log_path(config_file_name)
     with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file:
         time_now = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
@@ -120,9 +122,15 @@ def start_rest_server(port, platform, mode, config_file_name, experiment_id=None
         stderr_file.write(log_header)
         if sys.platform == 'win32':
             from subprocess import CREATE_NEW_PROCESS_GROUP
-            process = Popen(cmds, cwd=entry_dir, stdout=stdout_file, stderr=stderr_file, creationflags=CREATE_NEW_PROCESS_GROUP)
+            if args.foreground:
+                process = Popen(cmds, cwd=entry_dir, stdout=PIPE, stderr=STDOUT, creationflags=CREATE_NEW_PROCESS_GROUP)
+            else:
+                process = Popen(cmds, cwd=entry_dir, stdout=stdout_file, stderr=stderr_file, creationflags=CREATE_NEW_PROCESS_GROUP)
         else:
-            process = Popen(cmds, cwd=entry_dir, stdout=stdout_file, stderr=stderr_file)
+            if args.foreground:
+                process = Popen(cmds, cwd=entry_dir, stdout=PIPE, stderr=PIPE)
+            else:
+                process = Popen(cmds, cwd=entry_dir, stdout=stdout_file, stderr=stderr_file)
     return process, str(time_now)

 def set_trial_config(experiment_config, port, config_file_name):
@@ -424,7 +432,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
     if log_level not in ['trace', 'debug'] and (args.debug or experiment_config.get('debug') is True):
         log_level = 'debug'
     # start rest server
-    rest_process, start_time = start_rest_server(args.port, experiment_config['trainingServicePlatform'], \
+    rest_process, start_time = start_rest_server(args, experiment_config['trainingServicePlatform'], \
                                                  mode, config_file_name, experiment_id, log_dir, log_level)
     nni_config.set_config('restServerPid', rest_process.pid)
     # Deal with annotation
@@ -493,8 +501,14 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
                      experiment_config['experimentName'])
     print_normal(EXPERIMENT_SUCCESS_INFO % (experiment_id, ' '.join(web_ui_url_list)))
-    if args.watch:
-        set_monitor(True, 3, args.port, rest_process.pid)
+    if args.foreground:
+        try:
+            while True:
+                log_content = rest_process.stdout.readline().strip().decode('utf-8')
+                print(log_content)
+        except KeyboardInterrupt:
+            kill_command(rest_process.pid)
+            print_normal('Stopping experiment...')

 def create_experiment(args):
     '''start a new experiment'''
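The foreground branch above pipes the REST server's output back through nnictl instead of into log files. A self-contained sketch of that pattern, independent of nni (assumes a python3 on PATH):

    import subprocess

    proc = subprocess.Popen(['python3', '-u', '-c', 'print("hello from child")'],
                            stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    try:
        for raw in iter(proc.stdout.readline, b''):
            print(raw.strip().decode('utf-8'))  # mirror child output line by line
    except KeyboardInterrupt:
        proc.kill()                             # Ctrl-C also stops the child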
......
@@ -7,7 +7,7 @@ from schema import SchemaError
 from schema import Schema
 from .config_schema import LOCAL_CONFIG_SCHEMA, REMOTE_CONFIG_SCHEMA, PAI_CONFIG_SCHEMA, PAI_YARN_CONFIG_SCHEMA, KUBEFLOW_CONFIG_SCHEMA,\
     FRAMEWORKCONTROLLER_CONFIG_SCHEMA, tuner_schema_dict, advisor_schema_dict, assessor_schema_dict
-from .common_utils import print_error, print_warning, print_normal
+from .common_utils import print_error, print_warning, print_normal, get_yml_content

 def expand_path(experiment_config, key):
     '''Change '~' to user home directory'''
@@ -63,6 +63,8 @@ def parse_path(experiment_config, config_path):
     if experiment_config.get('machineList'):
         for index in range(len(experiment_config['machineList'])):
             expand_path(experiment_config['machineList'][index], 'sshKeyPath')
+    if experiment_config['trial'].get('paiConfigPath'):
+        expand_path(experiment_config['trial'], 'paiConfigPath')

     #if users use relative path, convert it to absolute path
     root_path = os.path.dirname(config_path)
@@ -94,6 +96,8 @@ def parse_path(experiment_config, config_path):
     if experiment_config.get('machineList'):
         for index in range(len(experiment_config['machineList'])):
             parse_relative_path(root_path, experiment_config['machineList'][index], 'sshKeyPath')
+    if experiment_config['trial'].get('paiConfigPath'):
+        parse_relative_path(root_path, experiment_config['trial'], 'paiConfigPath')

 def validate_search_space_content(experiment_config):
     '''Validate searchspace content,
@@ -254,6 +258,45 @@ def validate_machine_list(experiment_config):
         print_error('Please set machineList!')
         exit(1)

+def validate_pai_config_path(experiment_config):
+    '''validate paiConfigPath field'''
+    if experiment_config.get('trainingServicePlatform') == 'pai':
+        if experiment_config.get('trial', {}).get('paiConfigPath'):
+            # validate the file format of paiConfigPath, ensure it is yaml format
+            pai_config = get_yml_content(experiment_config['trial']['paiConfigPath'])
+            if experiment_config['trial'].get('image') is None:
+                if pai_config.get('prerequisites', [{}])[0].get('uri') is None:
+                    print_error('Please set image field, or set image uri in your own paiConfig!')
+                    exit(1)
+                experiment_config['trial']['image'] = pai_config['prerequisites'][0]['uri']
+            if experiment_config['trial'].get('gpuNum') is None:
+                if pai_config.get('taskRoles', {}).get('taskrole', {}).get('resourcePerInstance', {}).get('gpu') is None:
+                    print_error('Please set gpuNum field, or set resourcePerInstance gpu in your own paiConfig!')
+                    exit(1)
+                experiment_config['trial']['gpuNum'] = pai_config['taskRoles']['taskrole']['resourcePerInstance']['gpu']
+            if experiment_config['trial'].get('cpuNum') is None:
+                if pai_config.get('taskRoles', {}).get('taskrole', {}).get('resourcePerInstance', {}).get('cpu') is None:
+                    print_error('Please set cpuNum field, or set resourcePerInstance cpu in your own paiConfig!')
+                    exit(1)
+                experiment_config['trial']['cpuNum'] = pai_config['taskRoles']['taskrole']['resourcePerInstance']['cpu']
+            if experiment_config['trial'].get('memoryMB') is None:
+                if pai_config.get('taskRoles', {}).get('taskrole', {}).get('resourcePerInstance', {}).get('memoryMB') is None:
+                    print_error('Please set memoryMB field, or set resourcePerInstance memoryMB in your own paiConfig!')
+                    exit(1)
+                experiment_config['trial']['memoryMB'] = pai_config['taskRoles']['taskrole']['resourcePerInstance']['memoryMB']
+            if experiment_config['trial'].get('paiStoragePlugin') is None:
+                if pai_config.get('extras', {}).get('com.microsoft.pai.runtimeplugin', [{}])[0].get('plugin') is None:
+                    print_error('Please set paiStoragePlugin field, or set plugin in your own paiConfig!')
+                    exit(1)
+                experiment_config['trial']['paiStoragePlugin'] = pai_config['extras']['com.microsoft.pai.runtimeplugin'][0]['plugin']
+        else:
+            pai_trial_fields_required_list = ['image', 'gpuNum', 'cpuNum', 'memoryMB', 'paiStoragePlugin']
+            for trial_field in pai_trial_fields_required_list:
+                if experiment_config['trial'].get(trial_field) is None:
+                    print_error('Please set {0} in trial configuration,\
+                        or set additional pai configuration file path in paiConfigPath!'.format(trial_field))
+                    exit(1)
+
 def validate_pai_trial_conifg(experiment_config):
     '''validate the trial config in pai platform'''
     if experiment_config.get('trainingServicePlatform') in ['pai', 'paiYarn']:
@@ -269,6 +312,7 @@ def validate_pai_trial_conifg(experiment_config):
             print_warning(warning_information.format('dataDir'))
         if experiment_config.get('trial').get('outputDir'):
             print_warning(warning_information.format('outputDir'))
+    validate_pai_config_path(experiment_config)

 def validate_all_content(experiment_config, config_path):
     '''Validate whether experiment_config is valid'''
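validate_pai_config_path leans on chained dict.get(...) calls with empty-dict defaults, so a missing intermediate key degrades to None instead of raising KeyError. The pattern in isolation (a generic sketch, not NNI code):

    pai_config = {'taskRoles': {'taskrole': {'resourcePerInstance': {'gpu': 1}}}}

    # Every intermediate .get() supplies {} so the next lookup never hits None.
    gpu = (pai_config.get('taskRoles', {})
                     .get('taskrole', {})
                     .get('resourcePerInstance', {})
                     .get('gpu'))
    print(gpu)  # 1 here; None if any level were absent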
......
@@ -51,7 +51,7 @@ def parse_args():
     parser_start.add_argument('--config', '-c', required=True, dest='config', help='the path of yaml config file')
     parser_start.add_argument('--port', '-p', default=DEFAULT_REST_PORT, dest='port', help='the port of restful server')
     parser_start.add_argument('--debug', '-d', action='store_true', help=' set debug mode')
-    parser_start.add_argument('--watch', '-w', action='store_true', help=' set watch mode')
+    parser_start.add_argument('--foreground', '-f', action='store_true', help=' set foreground mode, print log content to terminal')
     parser_start.set_defaults(func=create_experiment)

     # parse resume command
@@ -59,7 +59,7 @@ def parse_args():
     parser_resume.add_argument('id', nargs='?', help='The id of the experiment you want to resume')
     parser_resume.add_argument('--port', '-p', default=DEFAULT_REST_PORT, dest='port', help='the port of restful server')
     parser_resume.add_argument('--debug', '-d', action='store_true', help=' set debug mode')
-    parser_resume.add_argument('--watch', '-w', action='store_true', help=' set watch mode')
+    parser_resume.add_argument('--foreground', '-f', action='store_true', help=' set foreground mode, print log content to terminal')
     parser_resume.set_defaults(func=resume_experiment)

     # parse view command
......
@@ -403,11 +403,13 @@ def remote_clean(machine_list, experiment_id=None):
         userName = machine.get('username')
         host = machine.get('ip')
         port = machine.get('port')
+        sshKeyPath = machine.get('sshKeyPath')
+        passphrase = machine.get('passphrase')
         if experiment_id:
             remote_dir = '/' + '/'.join(['tmp', 'nni', 'experiments', experiment_id])
         else:
             remote_dir = '/' + '/'.join(['tmp', 'nni', 'experiments'])
-        sftp = create_ssh_sftp_client(host, port, userName, passwd)
+        sftp = create_ssh_sftp_client(host, port, userName, passwd, sshKeyPath, passphrase)
         print_normal('removing folder {0}'.format(host + ':' + str(port) + remote_dir))
         remove_remote_directory(sftp, remote_dir)
......
@@ -30,12 +30,16 @@ def copy_remote_directory_to_local(sftp, remote_path, local_path):
     except Exception:
         pass

-def create_ssh_sftp_client(host_ip, port, username, password):
+def create_ssh_sftp_client(host_ip, port, username, password, ssh_key_path, passphrase):
     '''create ssh client'''
     try:
         paramiko = check_environment()
         conn = paramiko.Transport(host_ip, port)
-        conn.connect(username=username, password=password)
+        if ssh_key_path is not None:
+            ssh_key = paramiko.RSAKey.from_private_key_file(ssh_key_path, password=passphrase)
+            conn.connect(username=username, pkey=ssh_key)
+        else:
+            conn.connect(username=username, password=password)
         sftp = paramiko.SFTPClient.from_transport(conn)
         return sftp
     except Exception as exception:
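A hedged usage sketch of the extended helper (host, credentials and paths are illustrative; paramiko must be installed):

    # Key-based connection: the passphrase unlocks the private key file.
    sftp = create_ssh_sftp_client('10.0.0.2', 22, 'nni', None,
                                  '/home/nni/.ssh/id_rsa', 'key-passphrase')
    # Password-based connection: pass None for the key path to fall back.
    sftp = create_ssh_sftp_client('10.0.0.1', 22, 'nni', 'secret', None, None)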
......
@@ -37,12 +37,14 @@ def copy_data_from_remote(args, nni_config, trial_content, path_list, host_list,
     machine_dict = {}
     local_path_list = []
     for machine in machine_list:
-        machine_dict[machine['ip']] = {'port': machine['port'], 'passwd': machine['passwd'], 'username': machine['username']}
+        machine_dict[machine['ip']] = {'port': machine['port'], 'passwd': machine.get('passwd'), 'username': machine['username'],
+                                       'sshKeyPath': machine.get('sshKeyPath'), 'passphrase': machine.get('passphrase')}
     for index, host in enumerate(host_list):
         local_path = os.path.join(temp_nni_path, trial_content[index].get('id'))
         local_path_list.append(local_path)
         print_normal('Copying log data from %s to %s' % (host + ':' + path_list[index], local_path))
-        sftp = create_ssh_sftp_client(host, machine_dict[host]['port'], machine_dict[host]['username'], machine_dict[host]['passwd'])
+        sftp = create_ssh_sftp_client(host, machine_dict[host]['port'], machine_dict[host]['username'], machine_dict[host]['passwd'],
+                                      machine_dict[host]['sshKeyPath'], machine_dict[host]['passphrase'])
         copy_remote_directory_to_local(sftp, path_list[index], local_path)
         print_normal('Copy done!')
     return local_path_list
......