Unverified commit b99e2683 authored by Yuge Zhang, committed by GitHub

Migration of NAS tests (#4933)

parent c0239e9d
@@ -503,8 +503,18 @@ def _trace_cls(base, kw_only, call_super=True, inheritable=False):
             # store a copy of initial parameters
             args, kwargs = _formulate_arguments(base.__init__, args, kwargs, kw_only, is_class_init=True)

-            # calling serializable object init to initialize the full object
-            super().__init__(symbol=base, args=args, kwargs=kwargs, call_super=call_super)
+            try:
+                # calling serializable object init to initialize the full object
+                super().__init__(symbol=base, args=args, kwargs=kwargs, call_super=call_super)
+            except RecursionError as e:
+                warnings.warn(
+                    'Recursion error detected in initialization of wrapped object. '
+                    'Did you use `super(MyClass, self).__init__()` rather than `super().__init__()`? '
+                    'Please use `super().__init__()` and try again. '
+                    f'Original error: {e}',
+                    RuntimeWarning
+                )
+                raise

         def __reduce__(self):
             # The issue that decorator and pickler doesn't play well together is well known.
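For context, a minimal sketch (not part of the commit; `Base`, `Good` and `Bad` are made-up names) of the super() pattern this new warning is about. Since `@nni.trace` rebinds the class name to a generated subclass, the zero-argument form is the safe one, while the explicit `super(MyClass, self)` form is what the warning blames for the RecursionError.

import nni

class Base:
    def __init__(self, x):
        self.x = x

@nni.trace
class Good(Base):
    def __init__(self, x):
        # Zero-argument super() resolves against the class being defined,
        # so it keeps working after @nni.trace swaps in a traced subclass.
        super().__init__(x=x)

@nni.trace
class Bad(Base):
    def __init__(self, x):
        # `Bad` now refers to the traced subclass; this is the pattern the
        # warning above asks about, and per the commit it can end in RecursionError.
        super(Bad, self).__init__(x=x)

Good(x=1)   # fine
# Bad(x=1)  # the problematic pattern; rewrite as super().__init__(x=x)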
@@ -771,6 +781,11 @@ def _get_cls_or_func_name(cls_or_func: Any) -> str:

 def get_hybrid_cls_or_func_name(cls_or_func: Any, pickle_size_limit: int = 4096) -> str:
+    """Pickle a class or function object to a string.
+
+    It first tries to represent the object with an importable path.
+    If that doesn't work, it falls back to cloudpickle.
+    """
     try:
         name = _get_cls_or_func_name(cls_or_func)
         # import success, use a path format
......
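To illustrate the hybrid naming described in the new docstring, here is a rough standalone sketch of the same idea (importable path first, cloudpickle as fallback). The `hybrid_name` function and the `path:` / `bytes:` prefixes are illustrative only and are not NNI's actual format.

import base64
import importlib
import cloudpickle

def hybrid_name(obj, pickle_size_limit: int = 4096) -> str:
    try:
        module, qualname = obj.__module__, obj.__qualname__
        if module == '__main__' or '<locals>' in qualname:
            raise ImportError('not reachable via a plain import')
        # Verify the object can be reached again through a plain import.
        target = importlib.import_module(module)
        for part in qualname.split('.'):
            target = getattr(target, part)
        if target is not obj:
            raise ImportError('name does not resolve back to the same object')
        return f'path:{module}.{qualname}'
    except (AttributeError, ImportError):
        # Lambdas, locally defined classes, etc.: fall back to cloudpickle.
        payload = cloudpickle.dumps(obj)
        if len(payload) > pickle_size_limit:
            raise ValueError('Serialized object is too large; move it into an importable module.')
        return 'bytes:' + base64.b64encode(payload).decode()

# hybrid_name(dict) -> 'path:builtins.dict'; hybrid_name(lambda x: x) -> a 'bytes:...' blob.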
@@ -2,8 +2,9 @@
 # Licensed under the MIT license.

 import os
+import sys
 from dataclasses import dataclass
-from typing import Any, Union
+from typing import Any, Dict, Union, Optional

 from nni.experiment.config import utils, ExperimentConfig
@@ -33,6 +34,10 @@ class RetiariiExeConfig(ExperimentConfig):
     # new config field for NAS
     execution_engine: Union[str, ExecutionEngineConfig]

+    # Internal: to support customized fields in the trial command.
+    # Useful when a customized Python interpreter or environment variables are needed.
+    _trial_command_params: Optional[Dict[str, Any]] = None

     def __init__(self, training_service_platform: Union[str, None] = None,
                  execution_engine: Union[str, ExecutionEngineConfig] = 'py',
                  **kwargs):
@@ -46,15 +51,25 @@ class RetiariiExeConfig(ExperimentConfig):
         # TODO: maybe we should also allow users to specify trial_code_directory
         if str(self.trial_code_directory) != '.' and not os.path.isabs(self.trial_code_directory):
             raise ValueError(msg.format('trial_code_directory', self.trial_code_directory))
-        if self.trial_command != '_reserved' and \
-            not self.trial_command.startswith('python3 -m nni.retiarii.trial_entry '):
+
+        trial_command_tmpl = '{envs} {python} -m nni.retiarii.trial_entry {execution_engine}'
+        if self.trial_command != '_reserved' and '-m nni.retiarii.trial_entry' not in self.trial_command:
             raise ValueError(msg.format('trial_command', self.trial_command))

         if isinstance(self.execution_engine, str):
             self.execution_engine = execution_engine_config_factory(self.execution_engine)
-        if self.execution_engine.name in ('py', 'base', 'cgo'):
-            # TODO: replace python3 with more elegant approach
-            # maybe use sys.executable rendered in trial side (e.g., trial_runner)
-            self.trial_command = 'python3 -m nni.retiarii.trial_entry ' + self.execution_engine.name
+
+        _trial_command_params = {
+            # Default variables
+            'envs': '',
+            # TODO: maybe use sys.executable rendered in trial side (e.g., trial_runner)
+            'python': sys.executable,
+            'execution_engine': self.execution_engine.name,
+
+            # This should override the parameters above.
+            **(self._trial_command_params or {})
+        }
+
+        self.trial_command = trial_command_tmpl.format(**_trial_command_params).strip()

         super()._canonicalize([self])
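As a quick illustration of how the new `_trial_command_params` hook composes the trial command (the override values below are hypothetical, not from the commit):

import sys

trial_command_tmpl = '{envs} {python} -m nni.retiarii.trial_entry {execution_engine}'
defaults = {'envs': '', 'python': sys.executable, 'execution_engine': 'py'}
overrides = {'envs': 'CUDA_VISIBLE_DEVICES=0', 'python': 'python3'}  # e.g. from config._trial_command_params

print(trial_command_tmpl.format(**{**defaults, **overrides}).strip())
# CUDA_VISIBLE_DEVICES=0 python3 -m nni.retiarii.trial_entry py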
@@ -13,7 +13,7 @@ resources:
     endpoint: github-filter-connection

 variables:
-  filter.modified.globs: 'examples/nas/**,nni/algorithms/nas/**,nni/nas/**,nni/retiarii/**'
+  filter.modified.globs: 'examples/nas/**,nni/algorithms/nas/**,nni/nas/**,nni/retiarii/**,pipelines/full-test-nas.yml,test/ut/nas/**,test/algo/nas/**'
   filter.prbody.heading: '#### Test Options'
   filter.prbody.optionIndex: 2
@@ -42,7 +42,36 @@ stages:
     - template: templates/install-nni.yml

+    - template: templates/download-test-data.yml

     - script: |
         cd test
-        source scripts/nas.sh
+        python -m pytest algo/nas
+      displayName: NAS test

+  - job: windows
+    pool: nni-it-windows
+    timeoutInMinutes: 60

+    steps:
+    - template: templates/install-dependencies.yml
+      parameters:
+        platform: windows
+        python_env: noop

+    - template: templates/install-nni.yml
+      parameters:
+        user: false

+    # NOTE: Data needs to be downloaded if Windows has GPU.
+    # Also, the download template needs to be updated with powershell syntax.
+    # - template: templates/download-test-data.yml

+    - powershell: |
+        python test/vso_tools/ssl_patch.py
+      displayName: SSL patch

+    - powershell: |
+        cd test
+        python -m pytest algo/nas
       displayName: NAS test
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
Extra tests for "algorithms", complementary to UT.

A test should be put here if it satisfies one of the following conditions:

1. The test could take a while to finish.
2. The test doesn't work on the free agent; it needs accelerators like GPUs.
3. The test is dedicated to a specific replaceable module and doesn't involve core functionality.

Note that if a test ensures the correctness of a "core function", without which NNI doesn't work at all,
it's still highly recommended to include at least a simple test in UT.
If the exhaustive version of such a test is expensive, it can still belong here.
"""
# Import ut to set environment variables
import ut
@@ -69,11 +69,6 @@ class TestConvert(unittest.TestCase, ConvertMixin):
         self.assertLess((a - b).abs().max().item(), 1E-4)
         return converted_model

-    def setUp(self):
-        # FIXME
-        import nni.retiarii.debug_configs
-        nni.retiarii.debug_configs.framework = 'pytorch'
-
     def test_dcgan_models(self):
         class DCGANGenerator(nn.Module):
             def __init__(self, nz, ngf, nc):
......
@@ -162,7 +162,7 @@ def _new_trainer():

 def _load_mnist(n_models: int = 1):
-    path = Path(__file__).parent / 'mnist_pytorch.json'
+    path = Path('ut/nas/mnist_pytorch.json')
     with open(path) as f:
         mnist_model = Model._load(nni.load(fp=f))
         mnist_model.evaluator = _new_trainer()
@@ -306,7 +306,6 @@ class CGOEngineTest(unittest.TestCase):

     def test_submit_models(self):
         _reset()
-        nni.retiarii.debug_configs.framework = 'pytorch'

         os.makedirs('generated', exist_ok=True)
         import nni.runtime.platform.test as tt
         protocol._set_out_file(open('generated/debug_protocol_out_file.py', 'wb'))
......
import multiprocessing
import os
import sys
import time

import pytest
import pytorch_lightning as pl

from nni.retiarii import strategy
from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment

from ut.nas.test_experiment import nas_experiment_trial_params, ensure_success
from .test_oneshot import _mnist_net

pytestmark = pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')


@pytest.mark.parametrize('model', [
    'simple', 'simple_value_choice', 'value_choice', 'repeat', 'custom_op'
])
def test_multi_trial(model, pytestconfig):
    evaluator_kwargs = {
        'max_epochs': 1
    }

    base_model, evaluator = _mnist_net(model, evaluator_kwargs)

    search_strategy = strategy.Random()
    exp = RetiariiExperiment(base_model, evaluator, strategy=search_strategy)
    exp_config = RetiariiExeConfig('local')
    exp_config.experiment_name = 'mnist_unittest'
    exp_config.trial_concurrency = 1
    exp_config.max_trial_number = 1
    exp_config._trial_command_params = nas_experiment_trial_params(pytestconfig.rootpath)
    exp.run(exp_config)
    ensure_success(exp)
    assert isinstance(exp.export_top_models()[0], dict)
    exp.stop()


def _test_experiment_in_separate_process(rootpath):
    try:
        base_model, evaluator = _mnist_net('simple', {'max_epochs': 1})

        search_strategy = strategy.Random()
        exp = RetiariiExperiment(base_model, evaluator, strategy=search_strategy)
        exp_config = RetiariiExeConfig('local')
        exp_config.experiment_name = 'mnist_unittest'
        exp_config.trial_concurrency = 1
        exp_config.max_trial_number = 1
        exp_config._trial_command_params = nas_experiment_trial_params(rootpath)
        exp.run(exp_config)
        ensure_success(exp)
        assert isinstance(exp.export_top_models()[0], dict)
    finally:
        # https://stackoverflow.com/questions/34506638/how-to-register-atexit-function-in-pythons-multiprocessing-subprocess
        import atexit
        atexit._run_exitfuncs()


def test_exp_exit_without_stop(pytestconfig):
    # NOTE: Multiprocessing has compatibility issues with OpenMP,
    # which makes the MNIST dataset fail to load on the pipeline.
    # https://github.com/pytorch/pytorch/issues/50669
    # Need to use spawn as a workaround of this issue.
    ctx = multiprocessing.get_context('spawn')
    process = ctx.Process(
        target=_test_experiment_in_separate_process,
        kwargs=dict(rootpath=pytestconfig.rootpath)
    )
    process.start()
    print('Waiting for experiment in sub-process.')
    timeout = 180
    for _ in range(timeout):
        if process.is_alive():
            time.sleep(1)
        else:
            assert process.exitcode == 0
            return
    process.kill()
    raise RuntimeError(f'Experiment fails to stop in {timeout} seconds.')
@@ -217,11 +217,11 @@ def _mnist_net(type_, evaluator_kwargs):
         raise ValueError(f'Unsupported type: {type_}')

     transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
-    train_dataset = nni.trace(MNIST)('data/mnist', train=True, download=True, transform=transform)
+    train_dataset = nni.trace(MNIST)('data/mnist', download=True, train=True, transform=transform)
     # Multi-GPU combined dataloader will break this subset sampler. Expected though.
     train_random_sampler = nni.trace(RandomSampler)(train_dataset, True, int(len(train_dataset) / 20))
     train_loader = nni.trace(DataLoader)(train_dataset, 64, sampler=train_random_sampler)
-    valid_dataset = nni.trace(MNIST)('data/mnist', train=False, download=True, transform=transform)
+    valid_dataset = nni.trace(MNIST)('data/mnist', download=True, train=False, transform=transform)
     valid_random_sampler = nni.trace(RandomSampler)(valid_dataset, True, int(len(valid_dataset) / 20))
     valid_loader = nni.trace(DataLoader)(valid_dataset, 64, sampler=valid_random_sampler)
     evaluator = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, **evaluator_kwargs)
......
@@ -17,7 +17,7 @@ from nni.retiarii.oneshot.pytorch.supermodule.proxyless import ProxylessMixedLay
 from nni.retiarii.oneshot.pytorch.supermodule._operation_utils import Slicable as S, MaybeWeighted as W
 from nni.retiarii.oneshot.pytorch.supermodule._valuechoice_utils import *

-from .models import (
+from ut.nas.models import (
     CellSimple, CellDefaultArgs, CellCustomProcessor, CellLooseEnd, CellOpFactory
 )
......