Commit b40e3db7 authored by quzha
Browse files

Merge branch 'master' of github.com:Microsoft/nni into dev-retiarii

parents efa4e31c 95f731e4
...@@ -28,6 +28,7 @@ logger = logging.getLogger('trial_keeper') ...@@ -28,6 +28,7 @@ logger = logging.getLogger('trial_keeper')
regular = re.compile('v?(?P<version>[0-9](\.[0-9]){0,1}).*') regular = re.compile('v?(?P<version>[0-9](\.[0-9]){0,1}).*')
_hdfs_client = None _hdfs_client = None
_trial_process = None
def get_hdfs_client(args): def get_hdfs_client(args):
...@@ -62,6 +63,7 @@ def get_hdfs_client(args): ...@@ -62,6 +63,7 @@ def get_hdfs_client(args):
def main_loop(args): def main_loop(args):
'''main loop logic for trial keeper''' '''main loop logic for trial keeper'''
global _trial_process
if not os.path.exists(LOG_DIR): if not os.path.exists(LOG_DIR):
os.makedirs(LOG_DIR) os.makedirs(LOG_DIR)
...@@ -90,13 +92,13 @@ def main_loop(args): ...@@ -90,13 +92,13 @@ def main_loop(args):
# Notice: We don't appoint env, which means subprocess will inherit current environment and that is expected behavior # Notice: We don't appoint env, which means subprocess will inherit current environment and that is expected behavior
log_pipe_stdout = trial_syslogger_stdout.get_pipelog_reader() log_pipe_stdout = trial_syslogger_stdout.get_pipelog_reader()
process = Popen(args.trial_command, shell=True, stdout=log_pipe_stdout, stderr=log_pipe_stdout) _trial_process = Popen(args.trial_command, shell=True, stdout=log_pipe_stdout, stderr=log_pipe_stdout, preexec_fn=os.setsid)
nni_log(LogType.Info, 'Trial keeper spawns a subprocess (pid {0}) to run command: {1}'.format(process.pid, nni_log(LogType.Info, 'Trial keeper spawns a subprocess (pid {0}) to run command: {1}'.format(_trial_process.pid,
shlex.split( shlex.split(
args.trial_command))) args.trial_command)))
while True: while True:
retCode = process.poll() retCode = _trial_process.poll()
# child worker process exits and all stdout data is read # child worker process exits and all stdout data is read
if retCode is not None and log_pipe_stdout.set_process_exit() and log_pipe_stdout.is_read_completed == True: if retCode is not None and log_pipe_stdout.set_process_exit() and log_pipe_stdout.is_read_completed == True:
# In Windows, the retCode -1 is 4294967295. It's larger than c_long, and raise OverflowError. # In Windows, the retCode -1 is 4294967295. It's larger than c_long, and raise OverflowError.
...@@ -213,6 +215,20 @@ def fetch_parameter_file(args): ...@@ -213,6 +215,20 @@ def fetch_parameter_file(args):
fetch_file_thread.start() fetch_file_thread.start()
def _set_adaptdl_signal_handler():
    """Install SIGTERM/SIGINT handlers that stop the trial before the keeper exits.

    When either signal arrives, the handler forwards SIGINT to the trial's
    process group, waits for the trial process, and then exits the keeper
    with status 1. If no trial has been spawned yet (``_trial_process`` is
    still ``None``), the keeper simply exits with status 1.
    """
    import signal
    import sys

    def _handler(signum, frame):
        """Forward the shutdown request to the trial process group, then exit(1)."""
        nni_log(LogType.Info, "RECEIVED SIGNAL {}".format(signum))
        # Read the global once so the None check and the uses below agree.
        trial = _trial_process
        if trial is not None:
            nni_log(LogType.Debug, "TRIAL PROCESS ID {}".format(trial.pid))
            if signum in (signal.SIGTERM, signal.SIGINT):
                try:
                    # The trial is started with preexec_fn=os.setsid, so it
                    # leads its own process group; signal the whole group.
                    os.killpg(os.getpgid(trial.pid), signal.SIGINT)
                    os.waitpid(trial.pid, 0)
                except (ProcessLookupError, ChildProcessError) as err:
                    # The trial already exited or was already reaped; there
                    # is nothing left to stop, so just proceed to exit.
                    nni_log(LogType.Debug, "TRIAL PROCESS ALREADY GONE: {}".format(err))
        # sys.exit raises SystemExit exactly like the site-provided exit(),
        # but does not depend on the site module being loaded.
        sys.exit(1)

    signal.signal(signal.SIGTERM, _handler)
    signal.signal(signal.SIGINT, _handler)
if __name__ == '__main__': if __name__ == '__main__':
'''NNI Trial Keeper main function''' '''NNI Trial Keeper main function'''
PARSER = argparse.ArgumentParser() PARSER = argparse.ArgumentParser()
...@@ -237,6 +253,8 @@ if __name__ == '__main__': ...@@ -237,6 +253,8 @@ if __name__ == '__main__':
try: try:
if NNI_PLATFORM == 'paiYarn' and is_multi_phase(): if NNI_PLATFORM == 'paiYarn' and is_multi_phase():
fetch_parameter_file(args) fetch_parameter_file(args)
if NNI_PLATFORM == 'adl':
_set_adaptdl_signal_handler()
main_loop(args) main_loop(args)
except SystemExit as se: except SystemExit as se:
nni_log(LogType.Info, 'NNI trial keeper exit with code {}'.format(se.code)) nni_log(LogType.Info, 'NNI trial keeper exit with code {}'.format(se.code))
......
...@@ -97,6 +97,21 @@ def get_sequence_id(): ...@@ -97,6 +97,21 @@ def get_sequence_id():
_intermediate_seq = 0 _intermediate_seq = 0
def overwrite_intermediate_seq(value):
    """
    Replace the module-level intermediate sequence value.

    Parameters
    ----------
    value : int
        The value to store in ``_intermediate_seq``. Anything that is not an
        ``int`` instance fails the assertion and leaves the stored value
        untouched.
    """
    global _intermediate_seq
    is_integer = isinstance(value, int)
    assert is_integer
    _intermediate_seq = value
def report_intermediate_result(metric): def report_intermediate_result(metric):
""" """
Reports intermediate result to NNI. Reports intermediate result to NNI.
......
...@@ -34,7 +34,7 @@ jobs: ...@@ -34,7 +34,7 @@ jobs:
set -e set -e
sudo apt-get install -y pandoc sudo apt-get install -y pandoc
python3 -m pip install -U --upgrade pygments python3 -m pip install -U --upgrade pygments
python3 -m pip install -U torch==1.5.0+cpu torchvision==0.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html python3 -m pip install -U torch==1.7.0+cpu torchvision==0.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
python3 -m pip install -U tensorflow==2.3.1 python3 -m pip install -U tensorflow==2.3.1
python3 -m pip install -U keras==2.4.2 python3 -m pip install -U keras==2.4.2
python3 -m pip install -U gym onnx peewee thop python3 -m pip install -U gym onnx peewee thop
...@@ -96,7 +96,7 @@ jobs: ...@@ -96,7 +96,7 @@ jobs:
- script: | - script: |
set -e set -e
python3 -m pip install -U torch==1.3.1+cpu torchvision==0.4.2+cpu -f https://download.pytorch.org/whl/torch_stable.html python3 -m pip install -U torch==1.5.0+cpu torchvision==0.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
python3 -m pip install -U tensorflow==1.15.2 python3 -m pip install -U tensorflow==1.15.2
python3 -m pip install -U keras==2.1.6 python3 -m pip install -U keras==2.1.6
python3 -m pip install -U gym onnx peewee python3 -m pip install -U gym onnx peewee
...@@ -131,12 +131,16 @@ jobs: ...@@ -131,12 +131,16 @@ jobs:
# This platform runs TypeScript unit test first. # This platform runs TypeScript unit test first.
steps: steps:
- task: UsePythonVersion@0
inputs:
versionSpec: 3.8
displayName: Configure Python
- script: | - script: |
set -e set -e
export PYTHON38_BIN_DIR=/usr/local/Cellar/python@3.8/`ls /usr/local/Cellar/python@3.8`/bin echo "##vso[task.setvariable variable=PATH]${PATH}:${HOME}/.local/bin"
echo "##vso[task.setvariable variable=PATH]${PYTHON38_BIN_DIR}:${HOME}/Library/Python/3.8/bin:${PATH}" python -m pip install -U --upgrade pip setuptools wheel
python3 -m pip install -U --upgrade pip setuptools python -m pip install -U pytest coverage
python3 -m pip install -U pytest coverage
displayName: 'Install Python tools' displayName: 'Install Python tools'
- script: | - script: |
...@@ -145,10 +149,9 @@ jobs: ...@@ -145,10 +149,9 @@ jobs:
- script: | - script: |
set -e set -e
cd ts/nni_manager export CI=true
yarn test (cd ts/nni_manager && yarn test)
cd ../nasui (cd ts/nasui && yarn test)
CI=true yarn test
displayName: 'TypeScript unit test' displayName: 'TypeScript unit test'
- script: | - script: |
...@@ -188,7 +191,7 @@ jobs: ...@@ -188,7 +191,7 @@ jobs:
displayName: 'Install Python tools' displayName: 'Install Python tools'
- script: | - script: |
python setup.py develop python setup.py develop --no-user
displayName: 'Install NNI' displayName: 'Install NNI'
- script: | - script: |
...@@ -201,16 +204,18 @@ jobs: ...@@ -201,16 +204,18 @@ jobs:
cd test cd test
python -m pytest ut python -m pytest ut
displayName: 'Python unit test' displayName: 'Python unit test'
continueOnError: true
- script: | - script: |
cd ts/nni_manager cd ts/nni_manager
yarn test yarn test
displayName: 'TypeScript unit test' displayName: 'TypeScript unit test'
continueOnError: true
- script: | - script: |
cd test cd test
python nni_test/nnitest/run_tests.py --config config/pr_tests.yml python nni_test/nnitest/run_tests.py --config config/pr_tests.yml
displayName: 'Simple integration test' displayName: 'Simple integration test'
continueOnError: true
trigger:
branches:
exclude: [ l10n_master ]
...@@ -72,7 +72,9 @@ dependencies = [ ...@@ -72,7 +72,9 @@ dependencies = [
'colorama', 'colorama',
'scikit-learn>=0.23.2', 'scikit-learn>=0.23.2',
'pkginfo', 'pkginfo',
'websockets' 'websockets',
'filelock',
'prettytable'
] ]
......
...@@ -35,8 +35,7 @@ class SimpleTuner(Tuner): ...@@ -35,8 +35,7 @@ class SimpleTuner(Tuner):
'checksum': None, 'checksum': None,
'path': '', 'path': '',
} }
_logger.info('generate parameter for father trial %s' % _logger.info('generate parameter for father trial %s', parameter_id)
parameter_id)
self.thread_lock.release() self.thread_lock.release()
return { return {
'prev_id': 0, 'prev_id': 0,
......
...@@ -18,7 +18,7 @@ class NaiveAssessor(Assessor): ...@@ -18,7 +18,7 @@ class NaiveAssessor(Assessor):
_logger.info('init') _logger.info('init')
def assess_trial(self, trial_job_id, trial_history): def assess_trial(self, trial_job_id, trial_history):
_logger.info('assess trial %s %s' % (trial_job_id, trial_history)) _logger.info('assess trial %s %s', trial_job_id, trial_history)
id_ = trial_history[0] id_ = trial_history[0]
if id_ in self._killed: if id_ in self._killed:
......
...@@ -21,17 +21,17 @@ class NaiveTuner(Tuner): ...@@ -21,17 +21,17 @@ class NaiveTuner(Tuner):
def generate_parameters(self, parameter_id, **kwargs): def generate_parameters(self, parameter_id, **kwargs):
self.cur += 1 self.cur += 1
_logger.info('generate parameters: %s' % self.cur) _logger.info('generate parameters: %s', self.cur)
return { 'x': self.cur } return { 'x': self.cur }
def receive_trial_result(self, parameter_id, parameters, value, **kwargs): def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
reward = extract_scalar_reward(value) reward = extract_scalar_reward(value)
_logger.info('receive trial result: %s, %s, %s' % (parameter_id, parameters, reward)) _logger.info('receive trial result: %s, %s, %s', parameter_id, parameters, reward)
_result.write('%d %d\n' % (parameters['x'], reward)) _result.write('%d %d\n' % (parameters['x'], reward))
_result.flush() _result.flush()
def update_search_space(self, search_space): def update_search_space(self, search_space):
_logger.info('update_search_space: %s' % search_space) _logger.info('update_search_space: %s', search_space)
with open(os.path.join(_pwd, 'tuner_search_space.json'), 'w') as file_: with open(os.path.join(_pwd, 'tuner_search_space.json'), 'w') as file_:
json.dump(search_space, file_) json.dump(search_space, file_)
......
...@@ -11,6 +11,7 @@ advisor: ...@@ -11,6 +11,7 @@ advisor:
optimize_mode: maximize optimize_mode: maximize
R: 60 R: 60
eta: 3 eta: 3
exec_mode: parallelism
trial: trial:
codeDir: ../../../examples/trials/mnist-advisor codeDir: ../../../examples/trials/mnist-advisor
command: python3 mnist.py command: python3 mnist.py
......
...@@ -124,7 +124,7 @@ def print_file_content(filepath): ...@@ -124,7 +124,7 @@ def print_file_content(filepath):
def print_trial_job_log(training_service, trial_jobs_url): def print_trial_job_log(training_service, trial_jobs_url):
trial_jobs = get_trial_jobs(trial_jobs_url) trial_jobs = get_trial_jobs(trial_jobs_url)
for trial_job in trial_jobs: for trial_job in trial_jobs:
trial_log_dir = os.path.join(get_experiment_dir(EXPERIMENT_URL), 'trials', trial_job['id']) trial_log_dir = os.path.join(get_experiment_dir(EXPERIMENT_URL), 'trials', trial_job['trialJobId'])
log_files = ['stderr', 'trial.log'] if training_service == 'local' else ['stdout_log_collection.log'] log_files = ['stderr', 'trial.log'] if training_service == 'local' else ['stdout_log_collection.log']
for log_file in log_files: for log_file in log_files:
print_file_content(os.path.join(trial_log_dir, log_file)) print_file_content(os.path.join(trial_log_dir, log_file))
......
...@@ -12,6 +12,7 @@ import numpy as np ...@@ -12,6 +12,7 @@ import numpy as np
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner from nni.algorithms.compression.pytorch.pruning import L1FilterPruner
from nni.compression.pytorch.utils.shape_dependency import ChannelDependency from nni.compression.pytorch.utils.shape_dependency import ChannelDependency
from nni.compression.pytorch.utils.mask_conflict import fix_mask_conflict from nni.compression.pytorch.utils.mask_conflict import fix_mask_conflict
from nni.compression.pytorch.utils.counter import count_flops_params
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
prefix = 'analysis_test' prefix = 'analysis_test'
...@@ -60,7 +61,6 @@ channel_dependency_ground_truth = { ...@@ -60,7 +61,6 @@ channel_dependency_ground_truth = {
unittest.TestLoader.sortTestMethodsUsing = None unittest.TestLoader.sortTestMethodsUsing = None
@unittest.skipIf(torch.__version__ >= '1.6.0', 'not supported')
class AnalysisUtilsTest(TestCase): class AnalysisUtilsTest(TestCase):
@unittest.skipIf(torch.__version__ < "1.3.0", "not supported") @unittest.skipIf(torch.__version__ < "1.3.0", "not supported")
def test_channel_dependency(self): def test_channel_dependency(self):
...@@ -138,5 +138,49 @@ class AnalysisUtilsTest(TestCase): ...@@ -138,5 +138,49 @@ class AnalysisUtilsTest(TestCase):
assert b_index1 == b_index2 assert b_index1 == b_index2
def test_flops_params(self):
    """Compare count_flops_params totals with fixed (flops, params) pairs.

    Covers a small model mixing several layer types (counted with
    ``mode='full'``), a model that calls the same conv layer repeatedly,
    and torchvision's resnet50.
    """

    class Model1(nn.Module):
        # conv -> batch norm -> leaky relu -> upsample -> adaptive pool -> flatten -> linear
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 5, 1, 1)
            self.bn = nn.BatchNorm2d(5)
            self.relu = nn.LeakyReLU()
            self.linear = nn.Linear(20, 10)
            self.upsample = nn.UpsamplingBilinear2d(size=2)
            self.pool = nn.AdaptiveAvgPool2d((2, 2))

        def forward(self, x):
            activated = self.relu(self.bn(self.conv(x)))
            pooled = self.pool(self.upsample(activated))
            flat = pooled.view(pooled.size(0), -1)
            return self.linear(flat)

    class Model2(nn.Module):
        # One conv followed by five applications of the same second conv.
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 5, 1, 1)
            self.conv2 = nn.Conv2d(5, 5, 1, 1)

        def forward(self, x):
            out = self.conv(x)
            for _ in range(5):
                out = self.conv2(out)
            return out

    small_input = (1, 3, 2, 2)

    flops, params, _ = count_flops_params(Model1(), small_input, mode='full', verbose=False)
    assert (flops, params) == (610, 240)

    flops, params, _ = count_flops_params(Model2(), small_input, verbose=False)
    assert (flops, params) == (560, 50)

    # Imported here, as in the original, so torchvision is only needed once
    # the two toy-model checks have passed.
    from torchvision.models import resnet50
    flops, params, _ = count_flops_params(resnet50(), (1, 3, 224, 224), verbose=False)
    assert (flops, params) == (4089184256, 25503912)
if __name__ == '__main__': if __name__ == '__main__':
main() main()
...@@ -47,7 +47,6 @@ def generate_random_sparsity_v2(model): ...@@ -47,7 +47,6 @@ def generate_random_sparsity_v2(model):
return cfg_list return cfg_list
@unittest.skipIf(torch.__version__ >= '1.6.0', 'not supported')
class DependencyawareTest(TestCase): class DependencyawareTest(TestCase):
@unittest.skipIf(torch.__version__ < "1.3.0", "not supported") @unittest.skipIf(torch.__version__ < "1.3.0", "not supported")
def test_dependency_aware_pruning(self): def test_dependency_aware_pruning(self):
......
...@@ -177,7 +177,6 @@ def channel_prune(model): ...@@ -177,7 +177,6 @@ def channel_prune(model):
pruner.compress() pruner.compress()
pruner.export_model(model_path=MODEL_FILE, mask_path=MASK_FILE) pruner.export_model(model_path=MODEL_FILE, mask_path=MASK_FILE)
@unittest.skipIf(torch.__version__ >= '1.6.0', 'not supported')
class SpeedupTestCase(TestCase): class SpeedupTestCase(TestCase):
def test_speedup_vgg16(self): def test_speedup_vgg16(self):
prune_model_l1(vgg16()) prune_model_l1(vgg16())
...@@ -218,11 +217,9 @@ class SpeedupTestCase(TestCase): ...@@ -218,11 +217,9 @@ class SpeedupTestCase(TestCase):
assert model.backbone2.conv2.out_channels == int(orig_model.backbone2.conv2.out_channels * SPARSITY) assert model.backbone2.conv2.out_channels == int(orig_model.backbone2.conv2.out_channels * SPARSITY)
assert model.backbone2.fc1.in_features == int(orig_model.backbone2.fc1.in_features * SPARSITY) assert model.backbone2.fc1.in_features == int(orig_model.backbone2.fc1.in_features * SPARSITY)
# FIXME: # FIXME: This test case might fail randomly, no idea why
# This test case failed on macOS: # Example: https://msrasrg.visualstudio.com/NNIOpenSource/_build/results?buildId=16282
# https://msrasrg.visualstudio.com/NNIOpenSource/_build/results?buildId=15658
@unittest.skipIf(sys.platform == 'darwin', 'Failed for unknown reason')
def test_speedup_integration(self): def test_speedup_integration(self):
for model_name in ['resnet18', 'squeezenet1_1', 'mobilenet_v2', 'densenet121', 'densenet169', 'inception_v3', 'resnet50']: for model_name in ['resnet18', 'squeezenet1_1', 'mobilenet_v2', 'densenet121', 'densenet169', 'inception_v3', 'resnet50']:
kwargs = { kwargs = {
...@@ -251,7 +248,7 @@ class SpeedupTestCase(TestCase): ...@@ -251,7 +248,7 @@ class SpeedupTestCase(TestCase):
zero_bn_bias(net) zero_bn_bias(net)
zero_bn_bias(speedup_model) zero_bn_bias(speedup_model)
data = torch.ones(BATCH_SIZE, 3, 224, 224).to(device) data = torch.ones(BATCH_SIZE, 3, 128, 128).to(device)
ms = ModelSpeedup(speedup_model, data, MASK_FILE) ms = ModelSpeedup(speedup_model, data, MASK_FILE)
ms.speedup_model() ms.speedup_model()
...@@ -281,7 +278,7 @@ class SpeedupTestCase(TestCase): ...@@ -281,7 +278,7 @@ class SpeedupTestCase(TestCase):
net.load_state_dict(state_dict) net.load_state_dict(state_dict)
net.eval() net.eval()
data = torch.randn(BATCH_SIZE, 3, 224, 224).to(device) data = torch.randn(BATCH_SIZE, 3, 128, 128).to(device)
ms = ModelSpeedup(net, data, MASK_FILE) ms = ModelSpeedup(net, data, MASK_FILE)
ms.speedup_model() ms.speedup_model()
ms.bound_model(data) ms.bound_model(data)
......
...@@ -151,12 +151,37 @@ prune_config = { ...@@ -151,12 +151,37 @@ prune_config = {
lambda model: validate_sparsity(model.conv1, 0.5, model.bias) lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
] ]
}, },
'autocompress': { 'autocompress_l1': {
'pruner_class': AutoCompressPruner, 'pruner_class': AutoCompressPruner,
'config_list': [{ 'config_list': [{
'sparsity': 0.5, 'sparsity': 0.5,
'op_types': ['Conv2d'], 'op_types': ['Conv2d'],
}], }],
'base_algo': 'l1',
'trainer': lambda model, optimizer, criterion, epoch, callback : model,
'evaluator': lambda model: 0.9,
'dummy_input': torch.randn([64, 1, 28, 28]),
'validators': []
},
'autocompress_l2': {
'pruner_class': AutoCompressPruner,
'config_list': [{
'sparsity': 0.5,
'op_types': ['Conv2d'],
}],
'base_algo': 'l2',
'trainer': lambda model, optimizer, criterion, epoch, callback : model,
'evaluator': lambda model: 0.9,
'dummy_input': torch.randn([64, 1, 28, 28]),
'validators': []
},
'autocompress_fpgm': {
'pruner_class': AutoCompressPruner,
'config_list': [{
'sparsity': 0.5,
'op_types': ['Conv2d'],
}],
'base_algo': 'fpgm',
'trainer': lambda model, optimizer, criterion, epoch, callback : model, 'trainer': lambda model, optimizer, criterion, epoch, callback : model,
'evaluator': lambda model: 0.9, 'evaluator': lambda model: 0.9,
'dummy_input': torch.randn([64, 1, 28, 28]), 'dummy_input': torch.randn([64, 1, 28, 28]),
...@@ -181,7 +206,7 @@ class Model(nn.Module): ...@@ -181,7 +206,7 @@ class Model(nn.Module):
def forward(self, x): def forward(self, x):
return self.fc(self.pool(self.bn1(self.conv1(x))).view(x.size(0), -1)) return self.fc(self.pool(self.bn1(self.conv1(x))).view(x.size(0), -1))
def pruners_test(pruner_names=['level', 'agp', 'slim', 'fpgm', 'l1', 'l2', 'taylorfo', 'mean_activation', 'apoz', 'netadapt', 'simulatedannealing', 'admm', 'autocompress'], bias=True): def pruners_test(pruner_names=['level', 'agp', 'slim', 'fpgm', 'l1', 'l2', 'taylorfo', 'mean_activation', 'apoz', 'netadapt', 'simulatedannealing', 'admm', 'autocompress_l1', 'autocompress_l2', 'autocompress_fpgm',], bias=True):
for pruner_name in pruner_names: for pruner_name in pruner_names:
print('testing {}...'.format(pruner_name)) print('testing {}...'.format(pruner_name))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
...@@ -203,8 +228,8 @@ def pruners_test(pruner_names=['level', 'agp', 'slim', 'fpgm', 'l1', 'l2', 'tayl ...@@ -203,8 +228,8 @@ def pruners_test(pruner_names=['level', 'agp', 'slim', 'fpgm', 'l1', 'l2', 'tayl
pruner = prune_config[pruner_name]['pruner_class'](model, config_list, evaluator=prune_config[pruner_name]['evaluator']) pruner = prune_config[pruner_name]['pruner_class'](model, config_list, evaluator=prune_config[pruner_name]['evaluator'])
elif pruner_name == 'admm': elif pruner_name == 'admm':
pruner = prune_config[pruner_name]['pruner_class'](model, config_list, trainer=prune_config[pruner_name]['trainer']) pruner = prune_config[pruner_name]['pruner_class'](model, config_list, trainer=prune_config[pruner_name]['trainer'])
elif pruner_name == 'autocompress': elif pruner_name.startswith('autocompress'):
pruner = prune_config[pruner_name]['pruner_class'](model, config_list, trainer=prune_config[pruner_name]['trainer'], evaluator=prune_config[pruner_name]['evaluator'], dummy_input=x) pruner = prune_config[pruner_name]['pruner_class'](model, config_list, trainer=prune_config[pruner_name]['trainer'], evaluator=prune_config[pruner_name]['evaluator'], dummy_input=x, base_algo=prune_config[pruner_name]['base_algo'])
else: else:
pruner = prune_config[pruner_name]['pruner_class'](model, config_list, optimizer) pruner = prune_config[pruner_name]['pruner_class'](model, config_list, optimizer)
pruner.compress() pruner.compress()
...@@ -264,7 +289,6 @@ class SimpleDataset: ...@@ -264,7 +289,6 @@ class SimpleDataset:
def __len__(self): def __len__(self):
return 1000 return 1000
@unittest.skipIf(torch.__version__ >= '1.6.0', 'not supported')
class PrunerTestCase(TestCase): class PrunerTestCase(TestCase):
def test_pruners(self): def test_pruners(self):
pruners_test(bias=True) pruners_test(bias=True)
...@@ -273,7 +297,7 @@ class PrunerTestCase(TestCase): ...@@ -273,7 +297,7 @@ class PrunerTestCase(TestCase):
pruners_test(bias=False) pruners_test(bias=False)
def test_agp_pruner(self): def test_agp_pruner(self):
for pruning_algorithm in ['l1', 'l2', 'taylorfo', 'apoz']: for pruning_algorithm in ['l1', 'l2', 'fpgm', 'taylorfo', 'apoz']:
_test_agp(pruning_algorithm) _test_agp(pruning_algorithm)
for pruning_algorithm in ['level']: for pruning_algorithm in ['level']:
......
...@@ -38,8 +38,7 @@ class MnistNetwork(object): ...@@ -38,8 +38,7 @@ class MnistNetwork(object):
input_dim = int(math.sqrt(self.x_dim)) input_dim = int(math.sqrt(self.x_dim))
except: except:
logger.debug( logger.debug(
'input dim cannot be sqrt and reshape. input dim: ' + 'input dim cannot be sqrt and reshape. input dim: ', str(self.x_dim))
str(self.x_dim))
raise raise
x_image = tf.reshape(self.x, [-1, input_dim, input_dim, 1]) x_image = tf.reshape(self.x, [-1, input_dim, input_dim, 1])
with tf.name_scope('conv1'): with tf.name_scope('conv1'):
...@@ -132,7 +131,7 @@ def main(): ...@@ -132,7 +131,7 @@ def main():
mnist_network.build_network() mnist_network.build_network()
logger.debug('Mnist build network done.') logger.debug('Mnist build network done.')
graph_location = tempfile.mkdtemp() graph_location = tempfile.mkdtemp()
logger.debug('Saving graph to: %s' % graph_location) logger.debug('Saving graph to: %s', graph_location)
train_writer = tf.summary.FileWriter(graph_location) train_writer = tf.summary.FileWriter(graph_location)
train_writer.add_graph(tf.get_default_graph()) train_writer.add_graph(tf.get_default_graph())
test_acc = 0.0 test_acc = 0.0
......
...@@ -53,7 +53,7 @@ class MnistNetwork(object): ...@@ -53,7 +53,7 @@ class MnistNetwork(object):
input_dim = int(math.sqrt(self.x_dim)) input_dim = int(math.sqrt(self.x_dim))
except: except:
#print('input dim cannot be sqrt and reshape. input dim: ' + str(self.x_dim)) #print('input dim cannot be sqrt and reshape. input dim: ' + str(self.x_dim))
logger.debug('input dim cannot be sqrt and reshape. input dim: ' + str(self.x_dim)) logger.debug('input dim cannot be sqrt and reshape. input dim: ', str(self.x_dim))
raise raise
x_image = tf.reshape(self.x, [-1, input_dim, input_dim, 1]) x_image = tf.reshape(self.x, [-1, input_dim, input_dim, 1])
...@@ -147,7 +147,7 @@ def main(): ...@@ -147,7 +147,7 @@ def main():
# Write log # Write log
graph_location = tempfile.mkdtemp() graph_location = tempfile.mkdtemp()
logger.debug('Saving graph to: %s' % graph_location) logger.debug('Saving graph to: %s', graph_location)
# print('Saving graph to: %s' % graph_location) # print('Saving graph to: %s' % graph_location)
train_writer = tf.summary.FileWriter(graph_location) train_writer = tf.summary.FileWriter(graph_location)
train_writer.add_graph(tf.get_default_graph()) train_writer.add_graph(tf.get_default_graph())
......
...@@ -11,9 +11,9 @@ from nni.tools.nnictl.nnictl_utils import get_yml_content ...@@ -11,9 +11,9 @@ from nni.tools.nnictl.nnictl_utils import get_yml_content
def create_mock_experiment(): def create_mock_experiment():
nnictl_experiment_config = Experiments() nnictl_experiment_config = Experiments()
nnictl_experiment_config.add_experiment('xOpEwA5w', '8080', '1970/01/1 01:01:01', 'aGew0x', nnictl_experiment_config.add_experiment('xOpEwA5w', '8080', 123456,
'local', 'example_sklearn-classification') 'local', 'example_sklearn-classification')
nni_config = Config('aGew0x') nni_config = Config('xOpEwA5w')
# mock process # mock process
cmds = ['sleep', '3600000'] cmds = ['sleep', '3600000']
process = Popen(cmds, stdout=PIPE, stderr=STDOUT) process = Popen(cmds, stdout=PIPE, stderr=STDOUT)
......
...@@ -3,7 +3,8 @@ ...@@ -3,7 +3,8 @@
from pathlib import Path from pathlib import Path
from subprocess import Popen, PIPE, STDOUT from subprocess import Popen, PIPE, STDOUT
from unittest import TestCase, main import sys
from unittest import TestCase, main, skipIf
from mock.restful_server import init_response from mock.restful_server import init_response
...@@ -28,8 +29,12 @@ class CommonUtilsTestCase(TestCase): ...@@ -28,8 +29,12 @@ class CommonUtilsTestCase(TestCase):
content = get_json_content(str(json_path)) content = get_json_content(str(json_path))
self.assertEqual(content, {'field':'test'}) self.assertEqual(content, {'field':'test'})
@skipIf(sys.platform == 'win32', 'FIXME: Fails randomly on Windows, cannot reproduce locally')
def test_detect_process(self): def test_detect_process(self):
cmds = ['sleep', '360000'] if sys.platform == 'win32':
cmds = ['timeout', '360000']
else:
cmds = ['sleep', '360000']
process = Popen(cmds, stdout=PIPE, stderr=STDOUT) process = Popen(cmds, stdout=PIPE, stderr=STDOUT)
self.assertTrue(detect_process(process.pid)) self.assertTrue(detect_process(process.pid))
kill_command(process.pid) kill_command(process.pid)
......
...@@ -19,7 +19,7 @@ class CommonUtilsTestCase(TestCase): ...@@ -19,7 +19,7 @@ class CommonUtilsTestCase(TestCase):
def test_update_experiment(self): def test_update_experiment(self):
experiment = Experiments(HOME_PATH) experiment = Experiments(HOME_PATH)
experiment.add_experiment('xOpEwA5w', 8081, 'N/A', 'aGew0x', 'local', 'test', endTime='N/A', status='INITIALIZED') experiment.add_experiment('xOpEwA5w', 8081, 'N/A', 'local', 'test', endTime='N/A', status='INITIALIZED')
self.assertTrue('xOpEwA5w' in experiment.get_all_experiments()) self.assertTrue('xOpEwA5w' in experiment.get_all_experiments())
experiment.remove_experiment('xOpEwA5w') experiment.remove_experiment('xOpEwA5w')
self.assertFalse('xOpEwA5w' in experiment.get_all_experiments()) self.assertFalse('xOpEwA5w' in experiment.get_all_experiments())
......
...@@ -46,7 +46,7 @@ class CommonUtilsTestCase(TestCase): ...@@ -46,7 +46,7 @@ class CommonUtilsTestCase(TestCase):
@responses.activate @responses.activate
def test_get_config_file_name(self): def test_get_config_file_name(self):
args = generate_args() args = generate_args()
self.assertEqual('aGew0x', get_config_filename(args)) self.assertEqual('xOpEwA5w', get_config_filename(args))
@responses.activate @responses.activate
def test_get_experiment_port(self): def test_get_experiment_port(self):
......
...@@ -43,9 +43,10 @@ interface MetricDataRecord { ...@@ -43,9 +43,10 @@ interface MetricDataRecord {
} }
interface TrialJobInfo { interface TrialJobInfo {
id: string; trialJobId: string;
sequenceId?: number; sequenceId?: number;
status: TrialJobStatus; status: TrialJobStatus;
message?: string;
startTime?: number; startTime?: number;
endTime?: number; endTime?: number;
hyperParameters?: string[]; hyperParameters?: string[];
...@@ -63,7 +64,7 @@ interface HyperParameterFormat { ...@@ -63,7 +64,7 @@ interface HyperParameterFormat {
interface ExportedDataFormat { interface ExportedDataFormat {
parameter: Record<string, any>; parameter: Record<string, any>;
value: Record<string, any>; value: Record<string, any>;
id: string; trialJobId: string;
} }
abstract class DataStore { abstract class DataStore {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment