Unverified Commit b99e2683 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Migration of NAS tests (#4933)

parent c0239e9d
...@@ -97,7 +97,7 @@ def _dataset_factory(dataset_type, subset=20): ...@@ -97,7 +97,7 @@ def _dataset_factory(dataset_type, subset=20):
if dataset_type == 'cifar10': if dataset_type == 'cifar10':
normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
train_dataset = nni.trace(CIFAR10)( train_dataset = nni.trace(CIFAR10)(
'../data/cifar10', 'data/cifar10',
train=True, train=True,
transform=transforms.Compose([ transform=transforms.Compose([
transforms.RandomHorizontalFlip(), transforms.RandomHorizontalFlip(),
...@@ -106,7 +106,7 @@ def _dataset_factory(dataset_type, subset=20): ...@@ -106,7 +106,7 @@ def _dataset_factory(dataset_type, subset=20):
normalize, normalize,
])) ]))
valid_dataset = nni.trace(CIFAR10)( valid_dataset = nni.trace(CIFAR10)(
'../data/cifar10', 'data/cifar10',
train=False, train=False,
transform=transforms.Compose([ transform=transforms.Compose([
transforms.ToTensor(), transforms.ToTensor(),
...@@ -115,7 +115,7 @@ def _dataset_factory(dataset_type, subset=20): ...@@ -115,7 +115,7 @@ def _dataset_factory(dataset_type, subset=20):
elif dataset_type == 'imagenet': elif dataset_type == 'imagenet':
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_dataset = nni.trace(ImageNet)( train_dataset = nni.trace(ImageNet)(
'../data/imagenet', 'data/imagenet',
split='val', # no train data available in tests split='val', # no train data available in tests
transform=transforms.Compose([ transform=transforms.Compose([
transforms.RandomResizedCrop(224), transforms.RandomResizedCrop(224),
...@@ -124,7 +124,7 @@ def _dataset_factory(dataset_type, subset=20): ...@@ -124,7 +124,7 @@ def _dataset_factory(dataset_type, subset=20):
normalize, normalize,
])) ]))
valid_dataset = nni.trace(ImageNet)( valid_dataset = nni.trace(ImageNet)(
'../data/imagenet', 'data/imagenet',
split='val', split='val',
transform=transforms.Compose([ transform=transforms.Compose([
transforms.Resize(256), transforms.Resize(256),
......
#!/bin/bash
# Run the SDK integration tests under coverage and emit an HTML report.
set -e
CWD=${PWD}

## Export certain environment variables for unittest code to work
export COVERAGE_PROCESS_START=${CWD}/.coveragerc
export COVERAGE_DATA_FILE=${CWD}/coverage/data
export COVERAGE_HTML_DIR=${CWD}/coverhtml

# -f: a missing previous data file must not abort the script under `set -e`.
rm -f ${COVERAGE_DATA_FILE}*
rm -rf ${COVERAGE_HTML_DIR}/*
# -p: re-runs must not fail when the directories already exist.
mkdir -p ${CWD}/coverage
mkdir -p ${COVERAGE_HTML_DIR}

## ------Run integration test------
echo "===========================Testing: integration test==========================="
coverage run sdk_test.py
coverage combine
coverage html
#!/bin/bash
# Run NAS integration tests: multi-trial Retiarii naive examples for every
# (network, execution-engine) pair, then the one-shot DARTS and ENAS examples.
set -e
CWD=${PWD}

echo ""
echo "===========================Testing: NAS==========================="
EXAMPLE_DIR="${CWD}/../examples/nas"
RETIARII_TEST_DIR="${CWD}/retiarii_test"

# Quote all expansions so paths containing spaces do not break `cd`/arguments.
cd "${RETIARII_TEST_DIR}/naive"
for net in "simple" "complex"; do
    for exec in "python" "graph"; do
        echo "testing multi-trial example on ${net}, ${exec}..."
        python3 search.py --net "$net" --exec "$exec"
    done
done

echo "testing darts..."
cd "${EXAMPLE_DIR}/oneshot/darts"
python3 search.py --epochs 1 --channels 2 --layers 4
python3 retrain.py --arc-checkpoint ./checkpoint.json --layers 4 --epochs 1

echo "testing enas..."
cd "${EXAMPLE_DIR}/oneshot/enas"
python3 search.py --search-for macro --epochs 1
python3 search.py --search-for micro --epochs 1

#disabled for now
#echo "testing naive..."
#cd $EXAMPLE_DIR/naive
#python3 train.py

#echo "testing pdarts..."
#cd $EXAMPLE_DIR/legacy/pdarts
#python3 search.py --epochs 1 --channels 4 --nodes 2 --log-frequency 10 --add_layers 0 --add_layers 1 --dropped_ops 3 --dropped_ops 3
import os
import sys
import nni
import pytorch_lightning
import pytest
import torch
import torch.nn.functional as F
import nni.retiarii.nn.pytorch as nn
import nni.retiarii.evaluator.pytorch.lightning as pl
from nni.retiarii import strategy, model_wrapper
from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment
from torchvision import transforms
from torchvision.datasets import MNIST
pytestmark = pytest.mark.skipif(pytorch_lightning.__version__ < '1.0', reason='Incompatible APIs')
def nas_experiment_trial_params(rootpath):
    """Build extra trial-command parameters so trial processes can import the
    in-repo package: prepend *rootpath* to ``PYTHONPATH`` using the shell
    syntax appropriate for the current platform.
    """
    if sys.platform == 'win32':
        env_prefix = f'set PYTHONPATH={rootpath} && '
    else:
        env_prefix = f'PYTHONPATH={rootpath}:$PYTHONPATH'
    return {'envs': env_prefix}
def ensure_success(exp: RetiariiExperiment):
    """Assert that *exp* created its working directory and its single job succeeded.

    On failure, dumps every file under the experiment's ``trials`` directory to
    stderr (so CI logs contain enough context) and raises ``RuntimeError``.
    """
    # check experiment directory exists
    exp_dir = os.path.join(
        exp.config.canonical_copy().experiment_working_directory,
        exp.id
    )
    assert os.path.exists(exp_dir) and os.path.exists(os.path.join(exp_dir, 'trials'))

    # check job status: exactly one trial, and it must have succeeded
    job_stats = exp.get_job_statistics()
    if not (len(job_stats) == 1 and job_stats[0]['trialJobStatus'] == 'SUCCEEDED'):
        print('Experiment jobs did not all succeed. Status is:', job_stats, file=sys.stderr)
        print('Trying to fetch trial logs.', file=sys.stderr)

        for root, _, files in os.walk(os.path.join(exp_dir, 'trials')):
            for file in files:
                fpath = os.path.join(root, file)
                print('=' * 10 + ' ' + fpath + ' ' + '=' * 10, file=sys.stderr)
                # context manager closes the handle even if read() raises
                # (the original `open(fpath).read()` leaked the file object)
                with open(fpath) as f:
                    print(f.read(), file=sys.stderr)

        raise RuntimeError('Experiment jobs did not all succeed.')
@model_wrapper
class Net(nn.Module):
    """Small LeNet-style model space for 1x32x32 (MNIST resized) inputs.

    Mutables: a ValueChoice over conv1's output channels, LayerChoices for both
    pooling stages, and an InputChoice acting as an optional extra FC layer.
    """

    def __init__(self):
        super().__init__()
        channels = nn.ValueChoice([4, 6, 8])
        self.conv1 = nn.Conv2d(1, channels, 5)
        self.pool1 = nn.LayerChoice([
            nn.MaxPool2d((2, 2)), nn.AvgPool2d((2, 2))
        ])
        self.conv2 = nn.Conv2d(channels, 16, 5)
        self.pool2 = nn.LayerChoice([
            nn.MaxPool2d(2), nn.AvgPool2d(2), nn.Conv2d(16, 16, 2, 2)
        ])
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5*5 from image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fcplus = nn.Linear(84, 84)
        # choose either the plain fc2 output or the extra fcplus transform
        self.shortcut = nn.InputChoice(2, 1)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # NOTE(review): removed a leftover debug `print(x.shape)` that spammed
        # stdout of every trial on every forward pass.
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except the batch dimension
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.shortcut([x, self.fcplus(x)])
        x = self.fc3(x)
        return x
def get_mnist_evaluator():
    """Create a lightweight MNIST classification evaluator.

    One epoch, capped at 20 train and 20 validation batches, so experiments in
    tests finish quickly. Datasets are wrapped in ``nni.trace`` for serialization.
    """
    preprocessing = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    loaders = {}
    for split, is_train in (('train', True), ('val', False)):
        dataset = nni.trace(MNIST)('data/mnist', download=True, train=is_train, transform=preprocessing)
        loaders[split] = pl.DataLoader(dataset, 64)
    return pl.Classification(
        train_dataloader=loaders['train'], val_dataloaders=loaders['val'],
        limit_train_batches=20,
        limit_val_batches=20,
        max_epochs=1
    )
def test_multitrial_experiment(pytestconfig):
    """Run a 1-trial multi-trial Retiarii experiment locally and check it succeeds."""
    base_model = Net()
    evaluator = get_mnist_evaluator()
    search_strategy = strategy.Random()
    exp = RetiariiExperiment(base_model, evaluator, strategy=search_strategy)
    exp_config = RetiariiExeConfig('local')
    exp_config.trial_concurrency = 1
    exp_config.max_trial_number = 1
    # trial subprocesses must be able to import the in-repo package
    exp_config._trial_command_params = nas_experiment_trial_params(pytestconfig.rootpath)
    exp.run(exp_config)
    ensure_success(exp)
    # the exported architecture is expected to be a plain dict of choices
    assert isinstance(exp.export_top_models()[0], dict)
    exp.stop()
def test_oneshot_experiment():
    """Run a one-shot (RandomOneShot) experiment in-process and export a model."""
    base_model = Net()
    evaluator = get_mnist_evaluator()
    search_strategy = strategy.RandomOneShot()
    exp = RetiariiExperiment(base_model, evaluator, strategy=search_strategy)
    exp_config = RetiariiExeConfig()
    # the 'oneshot' engine runs in-process: no training service or port needed
    exp_config.execution_engine = 'oneshot'
    exp.run(exp_config)
    assert isinstance(exp.export_top_models()[0], dict)
...@@ -6,13 +6,21 @@ from nni.retiarii import * ...@@ -6,13 +6,21 @@ from nni.retiarii import *
# FIXME # FIXME
import nni.retiarii.debug_configs import nni.retiarii.debug_configs
nni.retiarii.debug_configs.framework = 'tensorflow' original_framework = nni.retiarii.debug_configs.framework
max_pool = Operation.new('MaxPool2D', {'pool_size': 2}) max_pool = Operation.new('MaxPool2D', {'pool_size': 2})
avg_pool = Operation.new('AveragePooling2D', {'pool_size': 2}) avg_pool = Operation.new('AveragePooling2D', {'pool_size': 2})
global_pool = Operation.new('GlobalAveragePooling2D') global_pool = Operation.new('GlobalAveragePooling2D')
def setup_module(module):
nni.retiarii.debug_configs.framework = 'tensorflow'
def teardown_module(module):
nni.retiarii.debug_configs.framework = original_framework
class DebugSampler(Sampler): class DebugSampler(Sampler):
def __init__(self): def __init__(self):
self.iteration = 0 self.iteration = 0
...@@ -79,5 +87,7 @@ def _get_pools(model): ...@@ -79,5 +87,7 @@ def _get_pools(model):
if __name__ == '__main__': if __name__ == '__main__':
setup_module(None)
test_dry_run() test_dry_run()
test_mutation() test_mutation()
teardown_module(None)
import argparse
import os
import sys
import pytorch_lightning as pl
import pytest
from subprocess import Popen
from nni.retiarii import strategy
from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment
from .test_oneshot import _mnist_net
pytestmark = pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
def test_multi_trial():
    """Launch a 1-trial local experiment for each supported model-space variant."""
    evaluator_kwargs = {
        'max_epochs': 1
    }
    to_test = [
        # (model, evaluator)
        _mnist_net('simple', evaluator_kwargs),
        _mnist_net('simple_value_choice', evaluator_kwargs),
        _mnist_net('value_choice', evaluator_kwargs),
        _mnist_net('repeat', evaluator_kwargs),
        _mnist_net('custom_op', evaluator_kwargs),
    ]
    for base_model, evaluator in to_test:
        search_strategy = strategy.Random()
        exp = RetiariiExperiment(base_model, evaluator, strategy=search_strategy)
        exp_config = RetiariiExeConfig('local')
        exp_config.experiment_name = 'mnist_unittest'
        exp_config.trial_concurrency = 1
        exp_config.max_trial_number = 1
        exp_config.training_service.use_active_gpu = False
        exp.run(exp_config, 8080)
        # each run must export at least one architecture as a plain dict
        assert isinstance(exp.export_top_models()[0], dict)
        exp.stop()
# Standalone script used by test_exp_exit_without_stop below: it launches an
# experiment and exits WITHOUT calling exp.stop(), to verify the launching
# process still terminates cleanly on its own.
python_script = """
from nni.retiarii import strategy
from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment
from test_oneshot import _mnist_net
base_model, evaluator = _mnist_net('simple', {'max_epochs': 1})
search_strategy = strategy.Random()
exp = RetiariiExperiment(base_model, evaluator, strategy=search_strategy)
exp_config = RetiariiExeConfig('local')
exp_config.experiment_name = 'mnist_unittest'
exp_config.trial_concurrency = 1
exp_config.max_trial_number = 1
exp_config.training_service.use_active_gpu = False
exp.run(exp_config, 8080)
assert isinstance(exp.export_top_models()[0], dict)
"""
@pytest.mark.timeout(600)
def test_exp_exit_without_stop():
    """An experiment launched by a script that never calls ``exp.stop()`` must
    still let the process terminate (the 600s timeout guards against hangs).
    """
    script_name = 'tmp_multi_trial.py'
    with open(script_name, 'w') as f:
        f.write(python_script)
    try:
        proc = Popen([sys.executable, script_name])
        proc.wait()
        # previously the exit status was ignored, so a crashing script passed
        assert proc.returncode == 0
    finally:
        # clean up the temp script even when the subprocess fails
        os.remove(script_name)
if __name__ == '__main__':
    # Manual entry point: `python this_file.py --exp multi_trial` runs a single
    # test function; the default runs them all in order.
    parser = argparse.ArgumentParser()
    parser.add_argument('--exp', type=str, default='all', metavar='E',
                        help='experiment to run, default = all')
    args = parser.parse_args()
    if args.exp == 'all':
        test_multi_trial()
        test_exp_exit_without_stop()
    else:
        # resolve test_<name> dynamically from module globals
        globals()[f'test_{args.exp}']()
...@@ -21,14 +21,11 @@ def process_normal(): ...@@ -21,14 +21,11 @@ def process_normal():
def process_kill_slow(kill_time=2): def process_kill_slow(kill_time=2):
def handler_stop_signals(signum, frame): def handler_stop_signals(signum, frame):
print('debug proceess kill: signal received')
time.sleep(kill_time) time.sleep(kill_time)
print('debug proceess kill: signal processed')
sys.exit(0) sys.exit(0)
signal.signal(signal.SIGINT, handler_stop_signals) signal.signal(signal.SIGINT, handler_stop_signals)
signal.signal(signal.SIGTERM, handler_stop_signals) signal.signal(signal.SIGTERM, handler_stop_signals)
print('debug process kill: sleep')
time.sleep(360) time.sleep(360)
......
...@@ -10,7 +10,7 @@ import tkill from 'tree-kill'; ...@@ -10,7 +10,7 @@ import tkill from 'tree-kill';
import { NNIError, NNIErrorNames } from 'common/errors'; import { NNIError, NNIErrorNames } from 'common/errors';
import { getExperimentId } from 'common/experimentStartupInfo'; import { getExperimentId } from 'common/experimentStartupInfo';
import { getLogger, Logger } from 'common/log'; import { getLogger, Logger } from 'common/log';
import { powershellString } from 'common/shellUtils'; import { powershellString, shellString } from 'common/shellUtils';
import { import {
HyperParameters, TrainingService, TrialJobApplicationForm, HyperParameters, TrainingService, TrialJobApplicationForm,
TrialJobDetail, TrialJobMetric, TrialJobStatus TrialJobDetail, TrialJobMetric, TrialJobStatus
...@@ -413,17 +413,18 @@ class LocalTrainingService implements TrainingService { ...@@ -413,17 +413,18 @@ class LocalTrainingService implements TrainingService {
private getScript(workingDirectory: string): string[] { private getScript(workingDirectory: string): string[] {
const script: string[] = []; const script: string[] = [];
const escapedCommand = shellString(this.config.trialCommand);
if (process.platform === 'win32') { if (process.platform === 'win32') {
script.push(`$PSDefaultParameterValues = @{'Out-File:Encoding' = 'utf8'}`); script.push(`$PSDefaultParameterValues = @{'Out-File:Encoding' = 'utf8'}`);
script.push(`cd $env:NNI_CODE_DIR`); script.push(`cd $env:NNI_CODE_DIR`);
script.push( script.push(
`cmd.exe /c ${this.config.trialCommand} 1>${path.join(workingDirectory, 'stdout')} 2>${path.join(workingDirectory, 'stderr')}`, `cmd.exe /c ${escapedCommand} 1>${path.join(workingDirectory, 'stdout')} 2>${path.join(workingDirectory, 'stderr')}`,
`$NOW_DATE = [int64](([datetime]::UtcNow)-(get-date "1/1/1970")).TotalSeconds`, `$NOW_DATE = [int64](([datetime]::UtcNow)-(get-date "1/1/1970")).TotalSeconds`,
`$NOW_DATE = "$NOW_DATE" + (Get-Date -Format fff).ToString()`, `$NOW_DATE = "$NOW_DATE" + (Get-Date -Format fff).ToString()`,
`Write $LASTEXITCODE " " $NOW_DATE | Out-File "${path.join(workingDirectory, '.nni', 'state')}" -NoNewline -encoding utf8`); `Write $LASTEXITCODE " " $NOW_DATE | Out-File "${path.join(workingDirectory, '.nni', 'state')}" -NoNewline -encoding utf8`);
} else { } else {
script.push(`cd $NNI_CODE_DIR`); script.push(`cd $NNI_CODE_DIR`);
script.push(`eval ${this.config.trialCommand} 1>${path.join(workingDirectory, 'stdout')} 2>${path.join(workingDirectory, 'stderr')}`); script.push(`eval ${escapedCommand} 1>${path.join(workingDirectory, 'stdout')} 2>${path.join(workingDirectory, 'stderr')}`);
if (process.platform === 'darwin') { if (process.platform === 'darwin') {
// https://superuser.com/questions/599072/how-to-get-bash-execution-time-in-milliseconds-under-mac-os-x // https://superuser.com/questions/599072/how-to-get-bash-execution-time-in-milliseconds-under-mac-os-x
// Considering the worst case, write 999 to avoid negative duration // Considering the worst case, write 999 to avoid negative duration
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.