test.py 2.33 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
import os
import sys
import torch
from pathlib import Path

sys.path.append(str(Path(__file__).resolve().parents[2]))
from nni.retiarii.converter.graph_gen import convert_to_graph
from nni.retiarii.converter.visualize import visualize_model
from nni.retiarii.codegen.pytorch import model_to_pytorch_script

QuanluZhang's avatar
QuanluZhang committed
11
12
13
14
from nni.retiarii import nn
from nni.retiarii.trainer import PyTorchImageClassificationTrainer
from nni.retiarii.utils import TraceClassArguments

15
from base_mnasnet import MNASNet
QuanluZhang's avatar
QuanluZhang committed
16
17
18
19
20
21
from nni.experiment import RetiariiExperiment, RetiariiExpConfig

#from simple_strategy import SimpleStrategy
#from tpe_strategy import TPEStrategy
from nni.retiarii.strategies import TPEStrategy
from mutator import BlockMutator
22
23
24
25
26
27
28

if __name__ == '__main__':
    # Entry point: build the base MNASNet model and its trainer while recording
    # constructor arguments, then launch a local Retiarii experiment that
    # applies two BlockMutators under a TPE strategy.

    # Seven-element configuration lists passed positionally to MNASNet below.
    _DEFAULT_DEPTHS = [16, 24, 40, 80, 96, 192, 320]
    _DEFAULT_CONVOPS = ["dconv", "mconv", "mconv", "mconv", "mconv", "mconv", "mconv"]
    _DEFAULT_SKIPS = [False, True, True, True, True, True, True]
    _DEFAULT_KERNEL_SIZES = [3, 3, 5, 5, 3, 5, 3]
    _DEFAULT_NUM_LAYERS = [1, 3, 3, 3, 2, 4, 1]

    # Both the model and the trainer are constructed inside the tracing
    # context; `tca` is later handed to the experiment.
    with TraceClassArguments() as tca:
        # NOTE(review): 0.5 is presumably the width multiplier -- confirm
        # against base_mnasnet.MNASNet. The argument order here (kernel sizes
        # and layer counts before skips) differs from the declaration order
        # above and is intentional-looking; keep it as is.
        base_model = MNASNet(
            0.5,
            _DEFAULT_DEPTHS,
            _DEFAULT_CONVOPS,
            _DEFAULT_KERNEL_SIZES,
            _DEFAULT_NUM_LAYERS,
            _DEFAULT_SKIPS,
        )
        trainer = PyTorchImageClassificationTrainer(
            base_model,
            dataset_cls="CIFAR10",
            dataset_kwargs={"root": "data/cifar10", "download": True},
            dataloader_kwargs={"batch_size": 32},
            optimizer_kwargs={"lr": 1e-3},
            trainer_kwargs={"max_epochs": 1},
        )

    # Disabled debugging path (kept for reference, never executed): script the
    # base model, convert it to the graph IR, print the regenerated PyTorch
    # code and the IR dump, then visualize the IR.
    #
    # script_module = torch.jit.script(base_model)
    # model = convert_to_graph(script_module, base_model, tca.recorded_arguments)
    # code_script = model_to_pytorch_script(model)
    # print(code_script)
    # print("Model: ", model)
    # graph_ir = model._dump()
    # print(graph_ir)
    # visualize_model(graph_ir)

    # new interface
    applied_mutators = [
        BlockMutator('mutable_0'),
        BlockMutator('mutable_1'),
    ]

    search_strategy = TPEStrategy()

    exp = RetiariiExperiment(base_model, trainer, applied_mutators, search_strategy, tca)

    exp_config = RetiariiExpConfig.create_template('local')
    exp_config.experiment_name = 'mnasnet_search'
    exp_config.trial_concurrency = 2
    exp_config.max_trial_number = 10

    # Second positional argument of exp.run, named here for readability.
    experiment_port = 8081
    exp.run(exp_config, experiment_port, debug=True)