# This file is to demo the usage of multi-trial NAS in the usage of SPOS search space.

import click
import json
import nni.retiarii.evaluator.pytorch as pl
import nni.retiarii.strategy as strategy
from nni.retiarii import serialize
from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment
from torchvision import transforms
from torchvision.datasets import CIFAR10
from nn_meter import load_latency_predictor

from network import ShuffleNetV2OneShot
from utils import get_archchoice_by_model


class LatencyFilter:
    def __init__(self, threshold, predictor, predictor_version=None, reverse=False):
        """
        Filter the models according to predicted latency.

        Parameters
        ----------
        threshold: `float`
            the threshold of latency
        config, hardware:
            determine the targeted device
        reverse: `bool`
            if reverse is `False`, then the model returns `True` when `latency < threshold`,
            else otherwise
        """
        self.predictors = load_latency_predictor(predictor, predictor_version)
        self.threshold = threshold

    def __call__(self, ir_model):
        latency = self.predictors.predict(ir_model, 'nni-ir')
        return latency < self.threshold


@click.command()
@click.option('--port', default=8081, help='On which port the experiment is run.')
def _main(port):
    base_model = ShuffleNetV2OneShot(32)
    base_predictor = 'cortexA76cpu_tflite21'
    transf = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip()
    ]
    normalize = [
        transforms.ToTensor(),
        transforms.Normalize([0.49139968, 0.48215827, 0.44653124], [0.24703233, 0.24348505, 0.26158768])
    ]
    # FIXME
    # CIFAR10 is used here temporarily.
    # Actually we should load weight from supernet and evaluate on imagenet.
    train_dataset = serialize(CIFAR10, 'data', train=True, download=True, transform=transforms.Compose(transf + normalize))
    test_dataset = serialize(CIFAR10, 'data', train=False, transform=transforms.Compose(normalize))

    trainer = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=64),
                                val_dataloaders=pl.DataLoader(test_dataset, batch_size=64),
                                max_epochs=2, gpus=1)

    simple_strategy = strategy.RegularizedEvolution(model_filter=LatencyFilter(threshold=100, predictor=base_predictor),
                                                    sample_size=1, population_size=2, cycles=2)
    exp = RetiariiExperiment(base_model, trainer, strategy=simple_strategy)

    exp_config = RetiariiExeConfig('local')
    exp_config.trial_concurrency = 2
    # exp_config.max_trial_number = 2
    exp_config.trial_gpu_number = 1
    exp_config.training_service.use_active_gpu = False
    exp_config.execution_engine = 'base'
    exp_config.dummy_input = [1, 3, 32, 32]

    exp.run(exp_config, port)

    print('Exported models:')
    for i, model in enumerate(exp.export_top_models(formatter='dict')):
        print(model)
        with open(f'architecture_final_{i}.json', 'w') as f: 
            json.dump(get_archchoice_by_model(model), f, indent=4)


if __name__ == '__main__':
    _main()
