# search.py — MNIST architecture search example using NNI Retiarii.
1
2
3
import random

import nni.retiarii.nn.pytorch as nn
4
import nni.retiarii.strategy as strategy
5
import nni.retiarii.evaluator.pytorch.lightning as pl
6
import torch.nn.functional as F
7
from nni.retiarii import serialize, model_wrapper
8
from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment, debug_mutated_model
9
10
11
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
12

13
14
# Uncomment the decorator below when using the Python execution engine
# (the `exp_config.execution_engine = 'py'` option in the __main__ section).
# @model_wrapper
class Net(nn.Module):
    """MNIST-style CNN classifier whose first fully connected layer is searchable.

    Two conv + max-pool stages extract features, followed by two fully
    connected layers. ``fc1`` is an ``nn.LayerChoice`` between two ``Linear``
    candidates that differ only in whether they use a bias term, so the
    search strategy decides which one is used.

    Args:
        hidden_size: Number of output features of ``fc1`` (and input
            features of ``fc2``).
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        # Search space: same shape either way; the candidates differ only in bias.
        # `label` names this decision point for the search strategy / exported model.
        self.fc1 = nn.LayerChoice([
            nn.Linear(4 * 4 * 50, hidden_size),
            nn.Linear(4 * 4 * 50, hidden_size, bias=False),
        ], label='fc1_choice')
        self.fc2 = nn.Linear(hidden_size, 10)

    def forward(self, x):
        """Return per-class log-probabilities (log-softmax over 10 classes).

        NOTE(review): assumes ``x`` is (batch, 1, 28, 28) as produced by the
        MNIST pipeline in ``__main__`` — confirm for other inputs.
        """
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        # For 28x28 input: 28 -conv5-> 24 -pool2-> 12 -conv5-> 8 -pool2-> 4,
        # so each sample has 50 channels of 4x4 features = 4*4*50 values.
        x = x.view(-1, 4 * 4 * 50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


if __name__ == '__main__':
    # Base model; the searchable part is the LayerChoice inside Net.
    base_model = Net(128)

    # Convert images to tensors and normalize with a fixed single-channel mean/std.
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

    # Datasets are wrapped with NNI's `serialize` rather than constructed directly
    # (presumably so the experiment can record/re-create them — see NNI docs).
    train_dataset = serialize(MNIST, root='data/mnist', train=True, download=True, transform=transform)
    test_dataset = serialize(MNIST, root='data/mnist', train=False, download=True, transform=transform)

    # Each candidate model is trained for 2 epochs; the MNIST test split is
    # used as the validation set.
    trainer = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
                                val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
                                max_epochs=2)

    # Uncomment the following two lines to debug a generated model.
    # debug_mutated_model(base_model, trainer, [])
    # exit(0)

    # Sample candidate architectures at random.
    simple_strategy = strategy.Random()

    # The empty list passed as the third argument: no extra mutators are
    # supplied (the search space is declared inline via LayerChoice).
    exp = RetiariiExperiment(base_model, trainer, [], simple_strategy)

    exp_config = RetiariiExeConfig('local')
    exp_config.experiment_name = 'mnist_search'
    exp_config.trial_concurrency = 2
    exp_config.max_trial_number = 2
    exp_config.training_service.use_active_gpu = False

    export_formatter = 'code'

    # Uncomment these for the Python execution engine (also uncomment
    # the @model_wrapper decorator on Net).
    # exp_config.execution_engine = 'py'
    # export_formatter = 'dict'

    # Port in [8081, 8181]; randomized, presumably to avoid clashing with
    # another experiment already bound to a fixed port — confirm.
    port = 8081 + random.randint(0, 100)
    exp.run(exp_config, port)

    print('Final model:')
    for model_code in exp.export_top_models(formatter=export_formatter):
        print(model_code)