# test.py: NNI Retiarii search over MNASNet blocks on CIFAR-10 with the TPE strategy.
import os
import sys
import torch
from pathlib import Path

from nni.retiarii.trainer.pytorch import PyTorchImageClassificationTrainer
import nni.retiarii.trainer.pytorch.lightning as pl
from nni.retiarii import blackbox_module as bm
from base_mnasnet import MNASNet
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
from nni.retiarii.strategy import TPEStrategy
from torchvision import transforms
from torchvision.datasets import CIFAR10

from mutator import BlockMutator

if __name__ == '__main__':
    # Baseline MNASNet configuration. The lists below are parallel (seven
    # entries each, presumably one per network stage) and are passed
    # positionally to MNASNet -- its signature lives in base_mnasnet, so keep
    # the argument order used in the constructor call unchanged.
    _DEFAULT_DEPTHS = [16, 24, 40, 80, 96, 192, 320]
    _DEFAULT_CONVOPS = ["dconv", "mconv", "mconv", "mconv", "mconv", "mconv", "mconv"]
    _DEFAULT_SKIPS = [False, True, True, True, True, True, True]
    _DEFAULT_KERNEL_SIZES = [3, 3, 5, 5, 3, 5, 3]
    _DEFAULT_NUM_LAYERS = [1, 3, 3, 3, 2, 4, 1]

    # NOTE(review): 0.5 is presumably the width multiplier -- confirm in base_mnasnet.
    base_model = MNASNet(0.5, _DEFAULT_DEPTHS, _DEFAULT_CONVOPS, _DEFAULT_KERNEL_SIZES,
                         _DEFAULT_NUM_LAYERS, _DEFAULT_SKIPS)

    # Per-channel normalization statistics commonly used for CIFAR-10; shared
    # by both pipelines so train and validation inputs are scaled identically.
    cifar_mean = (0.4914, 0.4822, 0.4465)
    cifar_std = (0.2023, 0.1994, 0.2010)

    # Training adds augmentation (32px random crop with 4px padding, random
    # horizontal flip); validation only converts to tensor and normalizes.
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(cifar_mean, cifar_std),
    ])
    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(cifar_mean, cifar_std),
    ])

    # CIFAR10 is wrapped with blackbox_module, presumably so Retiarii can
    # record the constructor arguments and rebuild the datasets inside each
    # trial process (see NNI blackbox_module docs).
    train_dataset = bm(CIFAR10)(root='data/cifar10', train=True, download=True, transform=train_transform)
    test_dataset = bm(CIFAR10)(root='data/cifar10', train=False, download=True, transform=valid_transform)

    # Keep each trial cheap: a single epoch, with limit_train_batches=0.2
    # presumably forwarded to the Lightning Trainer (20% of training batches).
    trainer = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
                                val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
                                max_epochs=1, limit_train_batches=0.2)

    # One BlockMutator per searchable block. The labels are assumed to match
    # the mutable labels created inside MNASNet -- verify against base_mnasnet.
    applied_mutators = [
        BlockMutator('mutable_0'),
        BlockMutator('mutable_1'),
    ]

    # Tree-structured Parzen Estimator strategy picks which mutations to try.
    search_strategy = TPEStrategy()

    exp = RetiariiExperiment(base_model, trainer, applied_mutators, search_strategy)

    # Local training service: at most 2 concurrent trials, 10 trials total,
    # and do not submit trials to GPUs already occupied by other processes.
    exp_config = RetiariiExeConfig('local')
    exp_config.experiment_name = 'mnasnet_search'
    exp_config.trial_concurrency = 2
    exp_config.max_trial_number = 10
    exp_config.training_service.use_active_gpu = False

    # 8081 is the port handed to the NNI experiment (manager / web UI port -- confirm).
    exp.run(exp_config, 8081)