# search.py: NNI Retiarii example, random-strategy architecture search on MNIST.
import random

import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

import nni.retiarii.evaluator.pytorch.lightning as pl
import nni.retiarii.nn.pytorch as nn
import nni.retiarii.strategy as strategy
from nni.retiarii import serialize, model_wrapper
from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment, debug_mutated_model

# Comment out the following line to use the graph-based execution engine.
@model_wrapper
class Net(nn.Module):
    """Small MNIST-style CNN whose first fully connected layer is a search choice.

    The only searchable decision is ``fc1``: a Retiarii ``LayerChoice`` between
    a ``Linear`` layer with a bias term and an otherwise identical one without.

    Args:
        hidden_size: width of the hidden layer between ``fc1`` and ``fc2``.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        # Search-space decision: with or without bias. The label identifies this
        # choice (presumably it is the key in exported architectures; confirm).
        self.fc1 = nn.LayerChoice([
            nn.Linear(4*4*50, hidden_size),
            nn.Linear(4*4*50, hidden_size, bias=False)
        ], label='fc1_choice')
        self.fc2 = nn.Linear(hidden_size, 10)

    def forward(self, x):
        """Return per-class log-probabilities of shape (batch, 10).

        Assumes single-channel 28x28 inputs (e.g. MNIST). Spatial size goes
        28 -> 24 (conv1) -> 12 (pool) -> 8 (conv2) -> 4 (pool), which is where
        the hard-coded ``4*4*50`` flattened feature size comes from.
        """
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        # Flatten to (batch, 800); -1 lets view infer the batch dimension.
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


if __name__ == '__main__':
    # Base model; its inline LayerChoice defines the search space.
    base_model = Net(128)

    # Commonly cited MNIST mean/std used for input normalization.
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

    # serialize(...) wraps the dataset constructor call (rather than passing a
    # plain instance), presumably so each trial can rebuild it; confirm in NNI docs.
    train_dataset = serialize(MNIST, root='data/mnist', train=True, download=True, transform=transform)
    test_dataset = serialize(MNIST, root='data/mnist', train=False, download=True, transform=transform)

    # Evaluator: trains each sampled model for 2 epochs, validating on the MNIST
    # test split. Uses pl.DataLoader (NNI's wrapper) rather than torch's DataLoader.
    trainer = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
                                val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
                                max_epochs=2)

    # Random search over the space defined by the model's mutation primitives.
    simple_strategy = strategy.Random()

    # The empty list means no extra mutators: the search space comes only from
    # the LayerChoice declared inside Net.
    exp = RetiariiExperiment(base_model, trainer, [], simple_strategy)

    exp_config = RetiariiExeConfig('local')
    exp_config.experiment_name = 'mnist_search'
    exp_config.trial_concurrency = 2
    exp_config.max_trial_number = 2
    exp_config.training_service.use_active_gpu = False
    export_formatter = 'dict'

    # Uncomment these lines to use the graph-based execution engine, which
    # exports the top models as code instead of as a dict of choices.
    # exp_config.execution_engine = 'base'
    # export_formatter = 'code'

    # Port in [8081, 8181] (randint is inclusive); randomized presumably to avoid
    # clashing with another experiment already running on a fixed port.
    exp.run(exp_config, 8081 + random.randint(0, 100))

    print('Final model:')
    for model_code in exp.export_top_models(formatter=export_formatter):
        print(model_code)