Unverified commit 0fbaff6c, authored by Chi Song and committed by GitHub

Remove duplicate data under the /tmp folder, and other small changes. (#2484)

parent 92b5fa14
......@@ -109,8 +109,7 @@ def main(args):
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
#data_dir = os.path.join(args['data_dir'], nni.get_trial_id())
data_dir = os.path.join(args['data_dir'], 'data')
data_dir = args['data_dir']
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(data_dir, train=True, download=True,
......
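For context, a minimal sketch of what the data_dir change in these example scripts amounts to (the `args` dict and paths below are illustrative, not taken from the diff): previously each trial resolved its own nested or per-trial dataset directory, so torchvision downloaded a fresh copy of MNIST for every trial; after the change, all trials reuse the single directory passed via `--data_dir`.

```python
import os

import nni
from torchvision import datasets

args = {'data_dir': './data'}  # illustrative value for --data_dir

# Old pattern: a nested or per-trial subdirectory such as ./data/<trial_id>,
# which made every trial download and store its own copy of MNIST.
old_dir = os.path.join(args['data_dir'], nni.get_trial_id())

# New pattern: all trials share the directory passed on the command line,
# so the dataset is downloaded at most once and then reused.
new_dir = args['data_dir']
datasets.MNIST(new_dir, train=True, download=True)
```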
......@@ -88,7 +88,7 @@ def main(args):
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
data_dir = os.path.join(args['data_dir'], nni.get_trial_id())
data_dir = args['data_dir']
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(data_dir, train=True, download=True,
......@@ -144,7 +144,7 @@ def get_params():
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument("--data_dir", type=str,
default='/tmp/pytorch/mnist/input_data', help="data directory")
default='./data', help="data directory")
parser.add_argument('--batch_size', type=int, default=64, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
......
......@@ -88,7 +88,7 @@ def main(args):
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
data_dir = os.path.join(args['data_dir'], nni.get_trial_id())
data_dir = args['data_dir']
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(data_dir, train=True, download=True,
......@@ -129,7 +129,7 @@ def get_params():
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument("--data_dir", type=str,
default='/tmp/pytorch/mnist/input_data', help="data directory")
default='./data', help="data directory")
parser.add_argument('--batch_size', type=int, default=64, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument("--batch_num", type=int, default=None)
......
......@@ -51,6 +51,8 @@ testCases:
- name: mnist-pytorch
configFile: test/config/examples/mnist-pytorch.yml
# download data first, to prevent concurrent issue.
launchCommand: python3 ../examples/trials/mnist-pytorch/mnist.py --epochs 1 --batch_num 0 --data_dir ../examples/trials/mnist-pytorch/data && nnictl create --config $configFile --debug
- name: mnist-annotation
configFile: test/config/examples/mnist-annotation.yml
......@@ -84,7 +86,7 @@ testCases:
configFile: test/config/examples/classic-nas-pytorch.yml
# remove search space file
stopCommand: nnictl stop
onExitCommand: python3 -c 'import os; os.remove("config/examples/nni-nas-search-space.json")'
onExitCommand: python3 -c "import os; os.remove('config/examples/nni-nas-search-space.json')"
trainingService: local
#########################################################################
......
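The launchCommand added above warms up the dataset before `nnictl create`, so concurrently scheduled trials do not race to download MNIST; the onExitCommand quote swap matches the shell-execution change in the test runner further down, presumably because cmd.exe (used when commands run through the shell on Windows) treats only double quotes as quoting. A minimal sketch of the pre-download step itself (torchvision assumed; the path mirrors the one in the launchCommand):

```python
from torchvision import datasets

# Fetch both MNIST splits into the shared example data directory once;
# trials launched afterwards only read from it instead of downloading.
data_dir = "../examples/trials/mnist-pytorch/data"
datasets.MNIST(data_dir, train=True, download=True)
datasets.MNIST(data_dir, train=False, download=True)
```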
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import sys
import os
import argparse
import subprocess
import time
import datetime
import shlex
import traceback
import json
import os
import subprocess
import sys
import time
import ruamel.yaml as yaml
from utils import get_experiment_status, get_yml_content, dump_yml_content, get_experiment_id, \
parse_max_duration_time, get_trial_stats, deep_update, print_trial_job_log, get_failed_trial_jobs, \
get_experiment_dir, print_experiment_log
from utils import GREEN, RED, CLEAR, STATUS_URL, TRIAL_JOBS_URL, EXPERIMENT_URL, REST_ENDPOINT, wait_for_port_available
import validators
from utils import (CLEAR, EXPERIMENT_URL, GREEN, RED, REST_ENDPOINT,
STATUS_URL, TRIAL_JOBS_URL, deep_update, dump_yml_content,
get_experiment_dir, get_experiment_id,
get_experiment_status, get_failed_trial_jobs,
get_trial_stats, get_yml_content, parse_max_duration_time,
print_experiment_log, print_trial_job_log,
wait_for_port_available)
it_variables = {}
def update_training_service_config(config, training_service):
it_ts_config = get_yml_content(os.path.join('config', 'training_service.yml'))
......@@ -39,6 +42,7 @@ def update_training_service_config(config, training_service):
deep_update(config, it_ts_config['all'])
deep_update(config, it_ts_config[training_service])
def prepare_config_file(test_case_config, it_config, args):
config_path = args.nni_source_dir + test_case_config['configFile']
test_yml_config = get_yml_content(config_path)
......@@ -63,6 +67,7 @@ def prepare_config_file(test_case_config, it_config, args):
return new_config_file
def run_test_case(test_case_config, it_config, args):
new_config_file = prepare_config_file(test_case_config, it_config, args)
# set configFile variable
......@@ -75,15 +80,16 @@ def run_test_case(test_case_config, it_config, args):
stop_command = get_command(test_case_config, 'stopCommand')
print('Stop command:', stop_command, flush=True)
if stop_command:
subprocess.run(shlex.split(stop_command))
subprocess.run(stop_command, shell=True)
exit_command = get_command(test_case_config, 'onExitCommand')
print('Exit command:', exit_command, flush=True)
if exit_command:
subprocess.run(shlex.split(exit_command), check=True)
subprocess.run(exit_command, shell=True, check=True)
# remove tmp config file
if os.path.exists(new_config_file):
os.remove(new_config_file)
def invoke_validator(test_case_config, nni_source_dir, training_service):
validator_config = test_case_config.get('validator')
if validator_config is None or validator_config.get('class') is None:
......@@ -100,10 +106,12 @@ def invoke_validator(test_case_config, nni_source_dir, training_service):
print_trial_job_log(training_service, TRIAL_JOBS_URL)
raise
def get_max_values(config_file):
experiment_config = get_yml_content(config_file)
return parse_max_duration_time(experiment_config['maxExecDuration']), experiment_config['maxTrialNum']
def get_command(test_case_config, commandKey):
command = test_case_config.get(commandKey)
if commandKey == 'launchCommand':
......@@ -121,11 +129,12 @@ def get_command(test_case_config, commandKey):
return command
def launch_test(config_file, training_service, test_case_config):
launch_command = get_command(test_case_config, 'launchCommand')
print('launch command: ', launch_command, flush=True)
proc = subprocess.run(shlex.split(launch_command))
proc = subprocess.run(launch_command, shell=True)
assert proc.returncode == 0, 'launch command failed with code %d' % proc.returncode
......@@ -176,6 +185,7 @@ def launch_test(config_file, training_service, test_case_config):
print_trial_job_log(training_service, TRIAL_JOBS_URL)
raise AssertionError('Failed to finish in maxExecDuration')
def case_excluded(name, excludes):
if name is None:
return False
......@@ -186,6 +196,7 @@ def case_excluded(name, excludes):
return True
return False
def case_included(name, cases):
assert cases is not None
for case in cases.split(','):
......@@ -193,9 +204,11 @@ def case_included(name, cases):
return True
return False
def match_platform(test_case_config):
return sys.platform in test_case_config['platform'].split(' ')
def match_training_service(test_case_config, cur_training_service):
case_ts = test_case_config['trainingService']
assert case_ts is not None
......@@ -205,6 +218,7 @@ def match_training_service(test_case_config, cur_training_service):
return True
return False
def run(args):
it_config = get_yml_content(args.config)
......@@ -227,7 +241,8 @@ def run(args):
continue
if not match_training_service(test_case_config, args.ts):
print('skipped {}, training service {} not match [{}]'.format(name, args.ts, test_case_config['trainingService']))
print('skipped {}, training service {} not match [{}]'.format(
name, args.ts, test_case_config['trainingService']))
continue
wait_for_port_available(8080, 30)
......@@ -244,7 +259,8 @@ if __name__ == '__main__':
parser.add_argument("--nni_source_dir", type=str, default='../')
parser.add_argument("--cases", type=str, default=None)
parser.add_argument("--exclude", type=str, default=None)
parser.add_argument("--ts", type=str, choices=['local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller'], default='local')
parser.add_argument("--ts", type=str, choices=['local', 'remote', 'pai',
'kubeflow', 'frameworkcontroller'], default='local')
args = parser.parse_args()
run(args)
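The switch from `shlex.split(...)` to `shell=True` in the launch, stop, and exit calls above is what allows a launchCommand to chain steps with `&&`, as the pre-download command earlier does. A minimal sketch of the difference, assuming a POSIX shell and using `echo` as a stand-in command:

```python
import shlex
import subprocess

cmd = "echo step-one && echo step-two"  # stand-in for a chained launchCommand

# Without a shell, shlex.split() passes '&&' to echo as a literal argument,
# so only one process runs and nothing is chained.
subprocess.run(shlex.split(cmd), check=True)   # prints: step-one && echo step-two

# With shell=True, the shell interprets '&&' and runs the second command
# only if the first one succeeded.
subprocess.run(cmd, shell=True, check=True)    # prints: step-one, then: step-two
```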
......@@ -11,7 +11,7 @@ jobs:
- script: |
set -e
python3 -m pip install scikit-learn==0.20.0 --user
python3 -m pip install torch==1.3.1 torchvision==0.4.1 -f https://download.pytorch.org/whl/torch_stable.html --user
python3 -m pip install torch==1.3.1 torchvision==0.4.2 -f https://download.pytorch.org/whl/torch_stable.html --user
python3 -m pip install tensorflow-gpu==2.2.0 tensorflow-estimator==2.2.0 --force --user
python3 -m pip install keras==2.4.2 --user
sudo apt-get install swig -y
......
......@@ -11,7 +11,7 @@ jobs:
- script: |
set -e
python3 -m pip install scikit-learn==0.20.0 --user
python3 -m pip install torchvision==0.4.1 --user
python3 -m pip install torchvision==0.4.2 --user
python3 -m pip install torch==1.3.1 --user
python3 -m pip install keras==2.1.6 --user
python3 -m pip install tensorflow-gpu==1.15.2 tensorflow-estimator==1.15.1 --force --user
......