Unverified commit 0918ea0c, authored by Yuge Zhang and committed by GitHub

NAS Benchmark integration (stage 1) (#4090)

parent 382286df
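
In brief, this change adds a 'benchmark' execution engine to Retiarii: instead of training each sampled architecture, the trial looks its metrics up in a NAS benchmark database. A minimal sketch of how it is meant to be enabled from user code, based on the configuration fields added below (`base_model`, `evaluator`, and `strategy` stand for the usual Retiarii setup and are not defined here; the experiment name is illustrative):

```python
from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment

exp = RetiariiExperiment(base_model, evaluator, [], strategy)

exp_config = RetiariiExeConfig('local')
exp_config.experiment_name = 'nasbench101 benchmark test'  # hypothetical name
exp_config.execution_engine = 'benchmark'   # routes trials to BenchmarkExecutionEngine
exp_config.benchmark = 'nasbench101'        # one of BenchmarkGraphData.SUPPORTED_BENCHMARK_LIST
exp.run(exp_config, 8081)
```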
@@ -30,9 +30,9 @@ class NasBench101(nn.Module):
         super().__init__()
         op_candidates = {
-            'conv3x3': lambda num_features: Conv3x3BnRelu(num_features, num_features),
-            'conv1x1': lambda num_features: Conv1x1BnRelu(num_features, num_features),
-            'maxpool': lambda num_features: nn.MaxPool2d(3, 1, 1)
+            'conv3x3-bn-relu': lambda num_features: Conv3x3BnRelu(num_features, num_features),
+            'conv1x1-bn-relu': lambda num_features: Conv1x1BnRelu(num_features, num_features),
+            'maxpool3x3': lambda num_features: nn.MaxPool2d(3, 1, 1)
         }
         # initial stem convolution
@@ -129,7 +129,8 @@ class NasBench101TrainingModule(pl.LightningModule):
 @click.option('--epochs', default=108, help='Training length.')
 @click.option('--batch_size', default=256, help='Batch size.')
 @click.option('--port', default=8081, help='On which port the experiment is run.')
-def _multi_trial_test(epochs, batch_size, port):
+@click.option('--benchmark', is_flag=True, default=False)
+def _multi_trial_test(epochs, batch_size, port, benchmark):
     # initalize dataset. Note that 50k+10k is used. It's a little different from paper
     transf = [
         transforms.RandomCrop(32, padding=4),
@@ -166,6 +167,10 @@ def _multi_trial_test(epochs, batch_size, port):
     exp_config.trial_gpu_number = 1
     exp_config.training_service.use_active_gpu = False

+    if benchmark:
+        exp_config.benchmark = 'nasbench101'
+        exp_config.execution_engine = 'benchmark'
+
     exp.run(exp_config, port)
...
@@ -118,7 +118,8 @@ class NasBench201TrainingModule(pl.LightningModule):
 @click.option('--epochs', default=12, help='Training length.')
 @click.option('--batch_size', default=256, help='Batch size.')
 @click.option('--port', default=8081, help='On which port the experiment is run.')
-def _multi_trial_test(epochs, batch_size, port):
+@click.option('--benchmark', is_flag=True, default=False)
+def _multi_trial_test(epochs, batch_size, port, benchmark):
     # initalize dataset. Note that 50k+10k is used. It's a little different from paper
     transf = [
         transforms.RandomCrop(32, padding=4),
@@ -155,6 +156,10 @@ def _multi_trial_test(epochs, batch_size, port):
     exp_config.trial_gpu_number = 1
     exp_config.training_service.use_active_gpu = False

+    if benchmark:
+        exp_config.benchmark = 'nasbench201-cifar100'
+        exp_config.execution_engine = 'benchmark'
+
     exp.run(exp_config, port)
...
 import os

+# TODO: need to be refactored to support automatic download
 DATABASE_DIR = os.environ.get("NASBENCHMARK_DIR", os.path.expanduser("~/.nni/nasbenchmark"))
...
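
Since the directory is resolved once at import time, a different database location can be selected by exporting `NASBENCHMARK_DIR` before the benchmark modules are imported. A minimal sketch (the path below is illustrative):

```python
import os

# Point benchmark queries at a custom copy of the database before importing nni's benchmark modules.
os.environ['NASBENCHMARK_DIR'] = '/data/nasbenchmark'

from nni.nas.benchmarks.constants import DATABASE_DIR
print(DATABASE_DIR)  # -> /data/nasbenchmark
```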
import os
import random
from typing import Dict, Any, List, Optional, Union, Tuple, Callable, Iterable
from ..graph import Model
from ..integration_api import receive_trial_parameters
from .base import BaseExecutionEngine
from .python import get_mutation_dict
class BenchmarkGraphData:

    SUPPORTED_BENCHMARK_LIST = [
        'nasbench101',
        'nasbench201-cifar10',
        'nasbench201-cifar100',
        'nasbench201-imagenet16',
        'nds-cifar10',
        'nds-imagenet',
        'nlp'
    ]

    def __init__(self, mutation: Dict[str, Any], benchmark: str,
                 metric_name: Optional[str] = None,
                 db_path: Optional[str] = None) -> None:
        self.mutation = mutation        # mutation dict, e.g., {'layer1': 'conv3x3', ...}
        self.benchmark = benchmark      # e.g., nasbench101, nasbench201, ...
        self.metric_name = metric_name  # name of the queried metric, e.g., valid_acc
        self.db_path = db_path          # path to the directory of the database

    def dump(self) -> dict:
        from nni.nas.benchmarks.constants import DATABASE_DIR
        return {
            'mutation': self.mutation,
            'benchmark': self.benchmark,
            'metric_name': self.metric_name,
            'db_path': self.db_path or DATABASE_DIR  # database path needs to be passed from manager to worker
        }

    @staticmethod
    def load(data) -> 'BenchmarkGraphData':
        return BenchmarkGraphData(data['mutation'], data['benchmark'], data['metric_name'], data['db_path'])
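
A quick sketch of the intended round trip, where `dump()` produces the plain dict sent as trial parameters and `load()` rebuilds the object on the trial side (the mutation keys below are illustrative; real ones come from `get_mutation_dict`):

```python
data = BenchmarkGraphData({'cell/op1': 'conv3x3-bn-relu'}, 'nasbench101')
payload = data.dump()                      # plain dict, safe to ship from manager to worker
restored = BenchmarkGraphData.load(payload)
assert restored.benchmark == 'nasbench101'
```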
class BenchmarkExecutionEngine(BaseExecutionEngine):
    """
    Execution engine that does not actually run any trial, but queries the database for results.

    The database query is done at the trial end to make sure intermediate metrics are available.
    It will also support an accelerated mode that returns the metric immediately, without even going
    through the NNI manager (not implemented yet).
    """

    def __init__(self, benchmark: Union[str, Callable[[BenchmarkGraphData], Tuple[float, List[float]]]], acceleration: bool = False):
        super().__init__()
        if isinstance(benchmark, str):
            assert benchmark in BenchmarkGraphData.SUPPORTED_BENCHMARK_LIST, \
                f'{benchmark} is not one of the supported benchmarks: {BenchmarkGraphData.SUPPORTED_BENCHMARK_LIST}'
        self.benchmark = benchmark
        self.acceleration = acceleration

    def pack_model_data(self, model: Model) -> Any:
        # called when a new model is submitted to the backend;
        # converts a Model into data that is acceptable at the trial end
        mutation = get_mutation_dict(model)
        graph_data = BenchmarkGraphData(mutation, self.benchmark)
        return graph_data

    @classmethod
    def trial_execute_graph(cls) -> None:
        graph_data = BenchmarkGraphData.load(receive_trial_parameters())
        os.environ['NASBENCHMARK_DIR'] = graph_data.db_path
        final, intermediates = cls.query_in_benchmark(graph_data)

        import nni
        for i in intermediates:
            nni.report_intermediate_result(i)
        nni.report_final_result(final)

    @staticmethod
    def query_in_benchmark(graph_data: BenchmarkGraphData) -> Tuple[float, List[float]]:
        if not isinstance(graph_data.benchmark, str):
            return graph_data.benchmark(graph_data)

        # built-in benchmarks with default query setting
        if graph_data.benchmark == 'nasbench101':
            from nni.nas.benchmarks.nasbench101 import query_nb101_trial_stats
            arch = None
            for t in graph_data.mutation.values():
                if isinstance(t, dict):
                    arch = t
            if arch is None:
                raise ValueError(f'Cannot identify architecture from mutation dict: {graph_data.mutation}')
            return _convert_to_final_and_intermediates(
                query_nb101_trial_stats(arch, 108, include_intermediates=True),
                'valid_acc'
            )
        elif graph_data.benchmark.startswith('nasbench201'):
            from nni.nas.benchmarks.nasbench201 import query_nb201_trial_stats
            dataset = graph_data.benchmark.split('-')[-1]
            return _convert_to_final_and_intermediates(
                query_nb201_trial_stats(_flatten_architecture(graph_data.mutation), 200, dataset, include_intermediates=True),
                'valid_acc',
            )
        elif graph_data.benchmark.startswith('nds'):
            # FIXME: not tested yet
            from nni.nas.benchmarks.nds import query_nds_trial_stats
            dataset = graph_data.benchmark.split('-')[-1]
            return _convert_to_final_and_intermediates(
                query_nds_trial_stats(None, None, None, None, _flatten_architecture(graph_data.mutation),
                                      dataset, include_intermediates=True),
                'valid_acc'
            )
        elif graph_data.benchmark.startswith('nlp'):
            # FIXME: not tested yet
            from nni.nas.benchmarks.nlp import query_nlp_trial_stats
            # TODO: not sure about the available datasets in this benchmark; the docs are missing
            return _convert_to_final_and_intermediates(
                query_nlp_trial_stats(_flatten_architecture(graph_data.mutation), 'ptb', include_intermediates=True),
                'valid_acc'
            )
        else:
            raise ValueError(f'{graph_data.benchmark} is not a supported benchmark.')
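
The type hint above also allows `benchmark` to be a callable that performs the lookup itself, mirroring `query_in_benchmark`. A hedged sketch of that contract (the function below is a toy stand-in, not part of this commit):

```python
def constant_benchmark(graph_data: BenchmarkGraphData) -> Tuple[float, List[float]]:
    # pretend every architecture reaches 90% accuracy after a short learning curve
    return 0.9, [0.3, 0.6, 0.9]

# BenchmarkExecutionEngine(constant_benchmark) would then bypass the built-in databases entirely;
# note that the experiment config path only exposes the string form (config.benchmark).
```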
def _flatten_architecture(mutation: Dict[str, Any], benchmark: Optional[str] = None):
    # STRONG ASSUMPTION HERE!
    # This assumes that the benchmarked search space is a one-level search space,
    # i.e., it is either ONE cell or ONE network.
    # Two-cell search spaces like NDS are not supported for now.
    # Some benchmarks even need special handling to pop out invalid keys. I don't think this is a good design.

    # support double underscore to be compatible with the naming convention in the base engine
    ret = {k.split('/')[-1].split('__')[-1]: v for k, v in mutation.items()}
    if benchmark == 'nasbench101':
        ret = {k: v for k, v in ret.items() if k.startswith('op') or k.startswith('input')}
        ret = {k: v if k.startswith('op') or isinstance(v, list) else [v] for k, v in ret.items()}
    return ret
def _convert_to_final_and_intermediates(benchmark_result: Iterable[Any], metric_name: str) -> Tuple[float, List[float]]:
    # convert benchmark results from the database into
    # a final result (float) and intermediate results (list of floats)
    benchmark_result = list(benchmark_result)
    assert len(benchmark_result) > 0, 'Invalid query. Results from benchmark are empty.'
    if len(benchmark_result) > 1:
        benchmark_result = random.choice(benchmark_result)
    else:
        benchmark_result = benchmark_result[0]
    return benchmark_result[metric_name], [i[metric_name] for i in benchmark_result['intermediates'] if i[metric_name] is not None]
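
To make the flattening concrete, a small sketch (the label prefixes are made up for illustration; only the suffix after `/` and `__` survives):

```python
mutation = {'model/cell__0_1': 'avg_pool_3x3', 'model/cell__0_2': 'conv_3x3'}
print(_flatten_architecture(mutation))
# {'0_1': 'avg_pool_3x3', '0_2': 'conv_3x3'}
```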
@@ -28,6 +28,16 @@ class PythonGraphData:

 class PurePythonExecutionEngine(BaseExecutionEngine):
+    """
+    This is the execution engine that doesn't rely on the Python-IR converter.
+
+    We don't explicitly state this independence for now. The front end needs to decide which converter
+    (or no converter) to use depending on the execution type. In the future, that logic may be moved into
+    this execution engine.
+
+    The execution engine needs to store the class path of the base model and its init parameters, so that it can
+    re-initialize the model with the mutation dict in the context, making the mutable modules instantiate as the
+    fixed (chosen) modules on the fly.
+    """
+
     @classmethod
     def pack_model_data(cls, model: Model) -> Any:
         mutation = get_mutation_dict(model)
...
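
A conceptual sketch of what the docstring describes, with simplified names (this is not the engine's actual code; the real implementation uses NNI's context-stack utilities to apply the mutation dict):

```python
import importlib

def recreate_model(class_path: str, init_kwargs: dict):
    # resolve the recorded class path back to the base model class
    module_name, class_name = class_path.rsplit('.', 1)
    cls = getattr(importlib.import_module(module_name), class_name)
    # in the real engine this call happens inside a "fixed architecture" context,
    # so LayerChoice / InputChoice modules instantiate themselves as the chosen candidates
    return cls(**init_kwargs)
```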
@@ -66,6 +66,9 @@ class RetiariiExeConfig(ConfigBase):
     # input used in GraphConverterWithShape. Currently support shape tuple only.
     dummy_input: Optional[List[int]] = None

+    # benchmark name used by the benchmark execution engine.
+    benchmark: Optional[str] = None
+
     def __init__(self, training_service_platform: Optional[str] = None, **kwargs):
         super().__init__(**kwargs)
         if training_service_platform is not None:
@@ -82,7 +85,7 @@ class RetiariiExeConfig(ConfigBase):
         if key == 'trial_code_directory' and not (value == Path('.') or os.path.isabs(value)):
             raise AttributeError(f'{key} is not supposed to be set in Retiarii mode by users!')
         if key == 'execution_engine':
-            assert value in ['base', 'py', 'cgo'], f'The specified execution engine "{value}" is not supported.'
+            assert value in ['base', 'py', 'cgo', 'benchmark'], f'The specified execution engine "{value}" is not supported.'
             self.__dict__['trial_command'] = 'python3 -m nni.retiarii.trial_entry ' + value
         self.__dict__[key] = value
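
A minimal sketch of the effect: assigning `execution_engine` both validates the value and rewrites `trial_command`, which is how benchmark trials end up going through `nni.retiarii.trial_entry`.

```python
from nni.retiarii.experiment.pytorch import RetiariiExeConfig

config = RetiariiExeConfig('local')
config.execution_engine = 'benchmark'        # validated against ['base', 'py', 'cgo', 'benchmark']
config.benchmark = 'nasbench201-cifar100'
print(config.trial_command)                  # python3 -m nni.retiarii.trial_entry benchmark
```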
@@ -186,8 +189,10 @@ class RetiariiExperiment(Experiment):
     def _start_strategy(self):
         base_model_ir, self.applied_mutators = preprocess_model(
-            self.base_model, self.trainer, self.applied_mutators, full_ir=self.config.execution_engine != 'py',
-            dummy_input=self.config.dummy_input)
+            self.base_model, self.trainer, self.applied_mutators,
+            full_ir=self.config.execution_engine not in ['py', 'benchmark'],
+            dummy_input=self.config.dummy_input
+        )

         _logger.info('Start strategy...')
         self.strategy.run(base_model_ir, self.applied_mutators)
@@ -224,6 +229,9 @@ class RetiariiExperiment(Experiment):
         elif self.config.execution_engine == 'py':
             from ..execution.python import PurePythonExecutionEngine
             engine = PurePythonExecutionEngine()
+        elif self.config.execution_engine == 'benchmark':
+            from ..execution.benchmark import BenchmarkExecutionEngine
+            engine = BenchmarkExecutionEngine(self.config.benchmark)
         set_execution_engine(engine)

         self.id = management.generate_experiment_id()
...
@@ -210,7 +210,7 @@ class NasBench201Cell(nn.Module):
                 inp = in_features if j == 0 else out_features
                 op_choices = OrderedDict([(key, cls(inp, out_features))
                                           for key, cls in op_candidates.items()])
-                node_ops.append(LayerChoice(op_choices, label=f'{self._label}__{j}_{tid}'))
+                node_ops.append(LayerChoice(op_choices, label=f'{self._label}__{j}_{tid}'))  # put __ here to be compatible with base engine
             self.layers.append(node_ops)

     def forward(self, inputs):
...
@@ -298,11 +298,11 @@ class NasBench101Cell(nn.Module):
             label, selected = get_fixed_dict(label)
             op_candidates = cls._make_dict(op_candidates)
             num_nodes = selected[f'{label}/num_nodes']
-            adjacency_list = [make_list(selected[f'{label}/input_{i}']) for i in range(1, num_nodes)]
+            adjacency_list = [make_list(selected[f'{label}/input{i}']) for i in range(1, num_nodes)]
             if sum([len(e) for e in adjacency_list]) > max_num_edges:
                 raise InvalidMutation(f'Expected {max_num_edges} edges, found: {adjacency_list}')
             return _NasBench101CellFixed(
-                [op_candidates[selected[f'{label}/op_{i}']] for i in range(1, num_nodes - 1)],
+                [op_candidates[selected[f'{label}/op{i}']] for i in range(1, num_nodes - 1)],
                 adjacency_list, in_features, out_features, num_nodes, projection)
         except NoContextError:
             return super().__new__(cls)
@@ -334,8 +334,8 @@ class NasBench101Cell(nn.Module):
         for i in range(1, max_num_nodes):
             if i < max_num_nodes - 1:
                 self.ops.append(LayerChoice(OrderedDict([(k, op(self.hidden_features)) for k, op in op_candidates.items()]),
-                                            label=f'{self._label}/op_{i}'))
-            self.inputs.append(InputChoice(i, None, label=f'{self._label}/input_{i}'))
+                                            label=f'{self._label}/op{i}'))
+            self.inputs.append(InputChoice(i, None, label=f'{self._label}/input{i}'))

     @property
     def label(self):
@@ -380,11 +380,22 @@ class NasBench101Mutator(Mutator):
                 break
         mutation_dict = {mut.mutator.label: mut.samples for mut in model.history}
         num_nodes = mutation_dict[f'{self.label}/num_nodes'][0]
-        adjacency_list = [mutation_dict[f'{self.label}/input_{i}'] for i in range(1, num_nodes)]
+        adjacency_list = [mutation_dict[f'{self.label}/input{i}'] for i in range(1, num_nodes)]
         if sum([len(e) for e in adjacency_list]) > max_num_edges:
             raise InvalidMutation(f'Expected {max_num_edges} edges, found: {adjacency_list}')
         matrix = _NasBench101CellFixed.build_connection_matrix(adjacency_list, num_nodes)
-        prune(matrix, [None] * len(matrix))  # dummy ops, possible to raise InvalidMutation inside
+
+        operations = ['IN'] + [mutation_dict[f'{self.label}/op{i}'][0] for i in range(1, num_nodes - 1)] + ['OUT']
+        assert len(operations) == len(matrix)
+        matrix, operations = prune(matrix, operations)  # possible to raise InvalidMutation inside
+
+        # NOTE: a hack to maintain a clean copy of what the nasbench101 cell looks like
+        self._cur_samples = {}
+        for i in range(1, len(matrix)):
+            if i + 1 < len(matrix):
+                self._cur_samples[f'op{i}'] = operations[i]
+            self._cur_samples[f'input{i}'] = [k for k in range(i) if matrix[k, i]]
+        self._cur_samples = [self._cur_samples]  # by design, _cur_samples is a list of samples

     def dry_run(self, model):
         return [], model
...
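
To make the `_cur_samples` bookkeeping above concrete, here is a toy, self-contained rerun of the same loop on a hand-written chain-shaped cell (the matrix and operation names are illustrative, not taken from the benchmark):

```python
import numpy as np

# pruned 4-node cell: IN -> conv3x3-bn-relu -> maxpool3x3 -> OUT
operations = ['IN', 'conv3x3-bn-relu', 'maxpool3x3', 'OUT']
matrix = np.array([[0, 1, 0, 0],
                   [0, 0, 1, 0],
                   [0, 0, 0, 1],
                   [0, 0, 0, 0]])  # matrix[k, i] == 1 iff node k feeds node i

cur_samples = {}
for i in range(1, len(matrix)):
    if i + 1 < len(matrix):
        cur_samples[f'op{i}'] = operations[i]
    cur_samples[f'input{i}'] = [k for k in range(i) if matrix[k, i]]
print(cur_samples)
# {'op1': 'conv3x3-bn-relu', 'input1': [0], 'op2': 'maxpool3x3', 'input2': [1], 'input3': [2]}
```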
@@ -11,7 +11,7 @@ import argparse

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('exec', choices=['base', 'py', 'cgo'])
+    parser.add_argument('exec', choices=['base', 'py', 'cgo', 'benchmark'])
     args = parser.parse_args()
     if args.exec == 'base':
         from .execution.base import BaseExecutionEngine
@@ -22,4 +22,7 @@ if __name__ == '__main__':
     elif args.exec == 'py':
         from .execution.python import PurePythonExecutionEngine
         engine = PurePythonExecutionEngine
+    elif args.exec == 'benchmark':
+        from .execution.benchmark import BenchmarkExecutionEngine
+        engine = BenchmarkExecutionEngine
+
     engine.trial_execute_graph()