Unverified Commit 40942f29 authored by Jiahang Xu's avatar Jiahang Xu Committed by GitHub
Browse files

refactor nn-Meter multi-trial to adapt new structure of nn-Meter (#3987)

parent 6d7efa19
...@@ -8,7 +8,7 @@ EndToEnd Multi-trial SPOS Demo ...@@ -8,7 +8,7 @@ EndToEnd Multi-trial SPOS Demo
Basically, this demo will select the model whose latency satisfy constraints to train. Basically, this demo will select the model whose latency satisfy constraints to train.
To run this demo, first install nn-Meter from source code (currently we haven't released this package, so development installation is required). To run this demo, first install nn-Meter from source code (Github repo link: https://github.com/microsoft/nn-Meter. Currently we haven't released this package, so development installation is required).
.. code-block:: bash .. code-block:: bash
......
...@@ -128,7 +128,7 @@ class ShuffleNetV2(nn.Module): ...@@ -128,7 +128,7 @@ class ShuffleNetV2(nn.Module):
class LatencyFilter: class LatencyFilter:
def __init__(self, threshold, config=None, hardware='', reverse=False): def __init__(self, threshold, predictor, predictor_version=None, reverse=False):
""" """
Filter the models according to predcted latency. Filter the models according to predcted latency.
...@@ -142,13 +142,7 @@ class LatencyFilter: ...@@ -142,13 +142,7 @@ class LatencyFilter:
if reverse is `False`, then the model returns `True` when `latency < threshold`, if reverse is `False`, then the model returns `True` when `latency < threshold`,
else otherwisse else otherwisse
""" """
default_config, default_hardware = get_default_config() self.predictors = load_latency_predictors(predictor, predictor_version)
if config is None:
config = default_config
if not hardware:
hardware = default_hardware
self.predictors = load_latency_predictors(config, hardware)
self.threshold = threshold self.threshold = threshold
def __call__(self, ir_model): def __call__(self, ir_model):
...@@ -160,6 +154,7 @@ class LatencyFilter: ...@@ -160,6 +154,7 @@ class LatencyFilter:
@click.option('--port', default=8081, help='On which port the experiment is run.') @click.option('--port', default=8081, help='On which port the experiment is run.')
def _main(port): def _main(port):
base_model = ShuffleNetV2(32) base_model = ShuffleNetV2(32)
base_predictor = 'cortexA76cpu_tflite21'
transf = [ transf = [
transforms.RandomCrop(32, padding=4), transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip() transforms.RandomHorizontalFlip()
...@@ -175,7 +170,7 @@ def _main(port): ...@@ -175,7 +170,7 @@ def _main(port):
val_dataloaders=pl.DataLoader(test_dataset, batch_size=64), val_dataloaders=pl.DataLoader(test_dataset, batch_size=64),
max_epochs=2, gpus=1) max_epochs=2, gpus=1)
simple_strategy = strategy.Random(model_filter=LatencyFilter(100)) simple_strategy = strategy.Random(model_filter=LatencyFilter(threshold=100, predictor=base_predictor))
exp = RetiariiExperiment(base_model, trainer, [], simple_strategy) exp = RetiariiExperiment(base_model, trainer, [], simple_strategy)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment