Unverified Commit 40942f29 authored by Jiahang Xu's avatar Jiahang Xu Committed by GitHub
Browse files

refactor nn-Meter multi-trial to adapt new structure of nn-Meter (#3987)

parent 6d7efa19
......@@ -8,7 +8,7 @@ EndToEnd Multi-trial SPOS Demo
Basically, this demo will select the model whose latency satisfy constraints to train.
To run this demo, first install nn-Meter from source code (currently we haven't released this package, so development installation is required).
To run this demo, first install nn-Meter from source code (Github repo link: https://github.com/microsoft/nn-Meter. Currently we haven't released this package, so development installation is required).
.. code-block:: bash
......
......@@ -128,7 +128,7 @@ class ShuffleNetV2(nn.Module):
class LatencyFilter:
def __init__(self, threshold, config=None, hardware='', reverse=False):
def __init__(self, threshold, predictor, predictor_version=None, reverse=False):
"""
Filter the models according to predcted latency.
......@@ -142,13 +142,7 @@ class LatencyFilter:
if reverse is `False`, then the model returns `True` when `latency < threshold`,
else otherwisse
"""
default_config, default_hardware = get_default_config()
if config is None:
config = default_config
if not hardware:
hardware = default_hardware
self.predictors = load_latency_predictors(config, hardware)
self.predictors = load_latency_predictors(predictor, predictor_version)
self.threshold = threshold
def __call__(self, ir_model):
......@@ -160,6 +154,7 @@ class LatencyFilter:
@click.option('--port', default=8081, help='On which port the experiment is run.')
def _main(port):
base_model = ShuffleNetV2(32)
base_predictor = 'cortexA76cpu_tflite21'
transf = [
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip()
......@@ -175,7 +170,7 @@ def _main(port):
val_dataloaders=pl.DataLoader(test_dataset, batch_size=64),
max_epochs=2, gpus=1)
simple_strategy = strategy.Random(model_filter=LatencyFilter(100))
simple_strategy = strategy.Random(model_filter=LatencyFilter(threshold=100, predictor=base_predictor))
exp = RetiariiExperiment(base_model, trainer, [], simple_strategy)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment