Unverified Commit 0fbaff6c authored by Chi Song's avatar Chi Song Committed by GitHub
Browse files

Remove duplicate data under /tmp folder, and other small changes. (#2484)

parent 92b5fa14
...@@ -109,8 +109,7 @@ def main(args): ...@@ -109,8 +109,7 @@ def main(args):
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
#data_dir = os.path.join(args['data_dir'], nni.get_trial_id()) data_dir = args['data_dir']
data_dir = os.path.join(args['data_dir'], 'data')
train_loader = torch.utils.data.DataLoader( train_loader = torch.utils.data.DataLoader(
datasets.MNIST(data_dir, train=True, download=True, datasets.MNIST(data_dir, train=True, download=True,
......
...@@ -88,7 +88,7 @@ def main(args): ...@@ -88,7 +88,7 @@ def main(args):
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
data_dir = os.path.join(args['data_dir'], nni.get_trial_id()) data_dir = args['data_dir']
train_loader = torch.utils.data.DataLoader( train_loader = torch.utils.data.DataLoader(
datasets.MNIST(data_dir, train=True, download=True, datasets.MNIST(data_dir, train=True, download=True,
...@@ -144,7 +144,7 @@ def get_params(): ...@@ -144,7 +144,7 @@ def get_params():
# Training settings # Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example') parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument("--data_dir", type=str, parser.add_argument("--data_dir", type=str,
default='/tmp/pytorch/mnist/input_data', help="data directory") default='./data', help="data directory")
parser.add_argument('--batch_size', type=int, default=64, metavar='N', parser.add_argument('--batch_size', type=int, default=64, metavar='N',
help='input batch size for training (default: 64)') help='input batch size for training (default: 64)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR', parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
...@@ -180,4 +180,4 @@ if __name__ == '__main__': ...@@ -180,4 +180,4 @@ if __name__ == '__main__':
main(params) main(params)
except Exception as exception: except Exception as exception:
logger.exception(exception) logger.exception(exception)
raise raise
\ No newline at end of file
...@@ -88,7 +88,7 @@ def main(args): ...@@ -88,7 +88,7 @@ def main(args):
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
data_dir = os.path.join(args['data_dir'], nni.get_trial_id()) data_dir = args['data_dir']
train_loader = torch.utils.data.DataLoader( train_loader = torch.utils.data.DataLoader(
datasets.MNIST(data_dir, train=True, download=True, datasets.MNIST(data_dir, train=True, download=True,
...@@ -129,7 +129,7 @@ def get_params(): ...@@ -129,7 +129,7 @@ def get_params():
# Training settings # Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example') parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument("--data_dir", type=str, parser.add_argument("--data_dir", type=str,
default='/tmp/pytorch/mnist/input_data', help="data directory") default='./data', help="data directory")
parser.add_argument('--batch_size', type=int, default=64, metavar='N', parser.add_argument('--batch_size', type=int, default=64, metavar='N',
help='input batch size for training (default: 64)') help='input batch size for training (default: 64)')
parser.add_argument("--batch_num", type=int, default=None) parser.add_argument("--batch_num", type=int, default=None)
......
...@@ -51,6 +51,8 @@ testCases: ...@@ -51,6 +51,8 @@ testCases:
- name: mnist-pytorch - name: mnist-pytorch
configFile: test/config/examples/mnist-pytorch.yml configFile: test/config/examples/mnist-pytorch.yml
# download data first, to prevent concurrent issue.
launchCommand: python3 ../examples/trials/mnist-pytorch/mnist.py --epochs 1 --batch_num 0 --data_dir ../examples/trials/mnist-pytorch/data && nnictl create --config $configFile --debug
- name: mnist-annotation - name: mnist-annotation
configFile: test/config/examples/mnist-annotation.yml configFile: test/config/examples/mnist-annotation.yml
...@@ -84,7 +86,7 @@ testCases: ...@@ -84,7 +86,7 @@ testCases:
configFile: test/config/examples/classic-nas-pytorch.yml configFile: test/config/examples/classic-nas-pytorch.yml
# remove search space file # remove search space file
stopCommand: nnictl stop stopCommand: nnictl stop
onExitCommand: python3 -c 'import os; os.remove("config/examples/nni-nas-search-space.json")' onExitCommand: python3 -c "import os; os.remove('config/examples/nni-nas-search-space.json')"
trainingService: local trainingService: local
######################################################################### #########################################################################
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
import sys
import os
import argparse import argparse
import subprocess
import time
import datetime import datetime
import shlex
import traceback
import json import json
import os
import subprocess
import sys
import time
import ruamel.yaml as yaml import ruamel.yaml as yaml
from utils import get_experiment_status, get_yml_content, dump_yml_content, get_experiment_id, \
parse_max_duration_time, get_trial_stats, deep_update, print_trial_job_log, get_failed_trial_jobs, \
get_experiment_dir, print_experiment_log
from utils import GREEN, RED, CLEAR, STATUS_URL, TRIAL_JOBS_URL, EXPERIMENT_URL, REST_ENDPOINT, wait_for_port_available
import validators import validators
from utils import (CLEAR, EXPERIMENT_URL, GREEN, RED, REST_ENDPOINT,
STATUS_URL, TRIAL_JOBS_URL, deep_update, dump_yml_content,
get_experiment_dir, get_experiment_id,
get_experiment_status, get_failed_trial_jobs,
get_trial_stats, get_yml_content, parse_max_duration_time,
print_experiment_log, print_trial_job_log,
wait_for_port_available)
it_variables = {} it_variables = {}
def update_training_service_config(config, training_service): def update_training_service_config(config, training_service):
it_ts_config = get_yml_content(os.path.join('config', 'training_service.yml')) it_ts_config = get_yml_content(os.path.join('config', 'training_service.yml'))
...@@ -39,6 +42,7 @@ def update_training_service_config(config, training_service): ...@@ -39,6 +42,7 @@ def update_training_service_config(config, training_service):
deep_update(config, it_ts_config['all']) deep_update(config, it_ts_config['all'])
deep_update(config, it_ts_config[training_service]) deep_update(config, it_ts_config[training_service])
def prepare_config_file(test_case_config, it_config, args): def prepare_config_file(test_case_config, it_config, args):
config_path = args.nni_source_dir + test_case_config['configFile'] config_path = args.nni_source_dir + test_case_config['configFile']
test_yml_config = get_yml_content(config_path) test_yml_config = get_yml_content(config_path)
...@@ -63,6 +67,7 @@ def prepare_config_file(test_case_config, it_config, args): ...@@ -63,6 +67,7 @@ def prepare_config_file(test_case_config, it_config, args):
return new_config_file return new_config_file
def run_test_case(test_case_config, it_config, args): def run_test_case(test_case_config, it_config, args):
new_config_file = prepare_config_file(test_case_config, it_config, args) new_config_file = prepare_config_file(test_case_config, it_config, args)
# set configFile variable # set configFile variable
...@@ -75,15 +80,16 @@ def run_test_case(test_case_config, it_config, args): ...@@ -75,15 +80,16 @@ def run_test_case(test_case_config, it_config, args):
stop_command = get_command(test_case_config, 'stopCommand') stop_command = get_command(test_case_config, 'stopCommand')
print('Stop command:', stop_command, flush=True) print('Stop command:', stop_command, flush=True)
if stop_command: if stop_command:
subprocess.run(shlex.split(stop_command)) subprocess.run(stop_command, shell=True)
exit_command = get_command(test_case_config, 'onExitCommand') exit_command = get_command(test_case_config, 'onExitCommand')
print('Exit command:', exit_command, flush=True) print('Exit command:', exit_command, flush=True)
if exit_command: if exit_command:
subprocess.run(shlex.split(exit_command), check=True) subprocess.run(exit_command, shell=True, check=True)
# remove tmp config file # remove tmp config file
if os.path.exists(new_config_file): if os.path.exists(new_config_file):
os.remove(new_config_file) os.remove(new_config_file)
def invoke_validator(test_case_config, nni_source_dir, training_service): def invoke_validator(test_case_config, nni_source_dir, training_service):
validator_config = test_case_config.get('validator') validator_config = test_case_config.get('validator')
if validator_config is None or validator_config.get('class') is None: if validator_config is None or validator_config.get('class') is None:
...@@ -100,10 +106,12 @@ def invoke_validator(test_case_config, nni_source_dir, training_service): ...@@ -100,10 +106,12 @@ def invoke_validator(test_case_config, nni_source_dir, training_service):
print_trial_job_log(training_service, TRIAL_JOBS_URL) print_trial_job_log(training_service, TRIAL_JOBS_URL)
raise raise
def get_max_values(config_file): def get_max_values(config_file):
experiment_config = get_yml_content(config_file) experiment_config = get_yml_content(config_file)
return parse_max_duration_time(experiment_config['maxExecDuration']), experiment_config['maxTrialNum'] return parse_max_duration_time(experiment_config['maxExecDuration']), experiment_config['maxTrialNum']
def get_command(test_case_config, commandKey): def get_command(test_case_config, commandKey):
command = test_case_config.get(commandKey) command = test_case_config.get(commandKey)
if commandKey == 'launchCommand': if commandKey == 'launchCommand':
...@@ -121,11 +129,12 @@ def get_command(test_case_config, commandKey): ...@@ -121,11 +129,12 @@ def get_command(test_case_config, commandKey):
return command return command
def launch_test(config_file, training_service, test_case_config): def launch_test(config_file, training_service, test_case_config):
launch_command = get_command(test_case_config, 'launchCommand') launch_command = get_command(test_case_config, 'launchCommand')
print('launch command: ', launch_command, flush=True) print('launch command: ', launch_command, flush=True)
proc = subprocess.run(shlex.split(launch_command)) proc = subprocess.run(launch_command, shell=True)
assert proc.returncode == 0, 'launch command failed with code %d' % proc.returncode assert proc.returncode == 0, 'launch command failed with code %d' % proc.returncode
...@@ -150,7 +159,7 @@ def launch_test(config_file, training_service, test_case_config): ...@@ -150,7 +159,7 @@ def launch_test(config_file, training_service, test_case_config):
experiment_id = get_experiment_id(EXPERIMENT_URL) experiment_id = get_experiment_id(EXPERIMENT_URL)
while True: while True:
waited_time = time.time() - bg_time waited_time = time.time() - bg_time
if waited_time > max_duration + 10: if waited_time > max_duration + 10:
print('waited: {}, max_duration: {}'.format(waited_time, max_duration)) print('waited: {}, max_duration: {}'.format(waited_time, max_duration))
break break
status = get_experiment_status(STATUS_URL) status = get_experiment_status(STATUS_URL)
...@@ -176,6 +185,7 @@ def launch_test(config_file, training_service, test_case_config): ...@@ -176,6 +185,7 @@ def launch_test(config_file, training_service, test_case_config):
print_trial_job_log(training_service, TRIAL_JOBS_URL) print_trial_job_log(training_service, TRIAL_JOBS_URL)
raise AssertionError('Failed to finish in maxExecDuration') raise AssertionError('Failed to finish in maxExecDuration')
def case_excluded(name, excludes): def case_excluded(name, excludes):
if name is None: if name is None:
return False return False
...@@ -186,6 +196,7 @@ def case_excluded(name, excludes): ...@@ -186,6 +196,7 @@ def case_excluded(name, excludes):
return True return True
return False return False
def case_included(name, cases): def case_included(name, cases):
assert cases is not None assert cases is not None
for case in cases.split(','): for case in cases.split(','):
...@@ -193,9 +204,11 @@ def case_included(name, cases): ...@@ -193,9 +204,11 @@ def case_included(name, cases):
return True return True
return False return False
def match_platform(test_case_config): def match_platform(test_case_config):
return sys.platform in test_case_config['platform'].split(' ') return sys.platform in test_case_config['platform'].split(' ')
def match_training_service(test_case_config, cur_training_service): def match_training_service(test_case_config, cur_training_service):
case_ts = test_case_config['trainingService'] case_ts = test_case_config['trainingService']
assert case_ts is not None assert case_ts is not None
...@@ -205,6 +218,7 @@ def match_training_service(test_case_config, cur_training_service): ...@@ -205,6 +218,7 @@ def match_training_service(test_case_config, cur_training_service):
return True return True
return False return False
def run(args): def run(args):
it_config = get_yml_content(args.config) it_config = get_yml_content(args.config)
...@@ -227,7 +241,8 @@ def run(args): ...@@ -227,7 +241,8 @@ def run(args):
continue continue
if not match_training_service(test_case_config, args.ts): if not match_training_service(test_case_config, args.ts):
print('skipped {}, training service {} not match [{}]'.format(name, args.ts, test_case_config['trainingService'])) print('skipped {}, training service {} not match [{}]'.format(
name, args.ts, test_case_config['trainingService']))
continue continue
wait_for_port_available(8080, 30) wait_for_port_available(8080, 30)
...@@ -244,7 +259,8 @@ if __name__ == '__main__': ...@@ -244,7 +259,8 @@ if __name__ == '__main__':
parser.add_argument("--nni_source_dir", type=str, default='../') parser.add_argument("--nni_source_dir", type=str, default='../')
parser.add_argument("--cases", type=str, default=None) parser.add_argument("--cases", type=str, default=None)
parser.add_argument("--exclude", type=str, default=None) parser.add_argument("--exclude", type=str, default=None)
parser.add_argument("--ts", type=str, choices=['local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller'], default='local') parser.add_argument("--ts", type=str, choices=['local', 'remote', 'pai',
'kubeflow', 'frameworkcontroller'], default='local')
args = parser.parse_args() args = parser.parse_args()
run(args) run(args)
...@@ -11,7 +11,7 @@ jobs: ...@@ -11,7 +11,7 @@ jobs:
- script: | - script: |
set -e set -e
python3 -m pip install scikit-learn==0.20.0 --user python3 -m pip install scikit-learn==0.20.0 --user
python3 -m pip install torch==1.3.1 torchvision==0.4.1 -f https://download.pytorch.org/whl/torch_stable.html --user python3 -m pip install torch==1.3.1 torchvision==0.4.2 -f https://download.pytorch.org/whl/torch_stable.html --user
python3 -m pip install tensorflow-gpu==2.2.0 tensorflow-estimator==2.2.0 --force --user python3 -m pip install tensorflow-gpu==2.2.0 tensorflow-estimator==2.2.0 --force --user
python3 -m pip install keras==2.4.2 --user python3 -m pip install keras==2.4.2 --user
sudo apt-get install swig -y sudo apt-get install swig -y
......
...@@ -11,7 +11,7 @@ jobs: ...@@ -11,7 +11,7 @@ jobs:
- script: | - script: |
set -e set -e
python3 -m pip install scikit-learn==0.20.0 --user python3 -m pip install scikit-learn==0.20.0 --user
python3 -m pip install torchvision==0.4.1 --user python3 -m pip install torchvision==0.4.2 --user
python3 -m pip install torch==1.3.1 --user python3 -m pip install torch==1.3.1 --user
python3 -m pip install keras==2.1.6 --user python3 -m pip install keras==2.1.6 --user
python3 -m pip install tensorflow-gpu==1.15.2 tensorflow-estimator==1.15.1 --force --user python3 -m pip install tensorflow-gpu==1.15.2 tensorflow-estimator==1.15.1 --force --user
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment