# search.py: run an ENAS architecture search (macro or micro search space) on CIFAR-10.
import logging
import time
from argparse import ArgumentParser

import torch
import torch.nn as nn

import datasets
from macro import GeneralNetwork
from micro import MicroNetwork
from nni.nas.pytorch import enas
from nni.nas.pytorch.callbacks import LRSchedulerCallback, ArchitectureCheckpoint
from utils import accuracy, reward_accuracy

# Line layout for every record: "[timestamp] LEVEL (logger-name/thread-name) message".
fmt = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s'

# Render %(asctime)s in local time. This assigns the *class* attribute, so it
# affects every logging.Formatter in the process (local time is also the
# stdlib default).
logging.Formatter.converter = time.localtime
formatter = logging.Formatter(fmt=fmt, datefmt='%m/%d/%Y, %I:%M:%S %p')

# NOTE: StreamHandler() with no argument writes to sys.stderr, despite the
# "std_out" in this handler's name.
std_out_info = logging.StreamHandler()
std_out_info.setFormatter(formatter)

# getLogger() with no name is the root logger; configure it at INFO level
# with the single handler above.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
logger.addHandler(std_out_info)

if __name__ == "__main__":
    parser = ArgumentParser("enas")
    parser.add_argument("--batch-size", default=128, type=int)
29
    parser.add_argument("--log-frequency", default=10, type=int)
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
    parser.add_argument("--search-for", choices=["macro", "micro"], default="macro")
    args = parser.parse_args()

    dataset_train, dataset_valid = datasets.get_dataset("cifar10")
    if args.search_for == "macro":
        model = GeneralNetwork()
        num_epochs = 310
        mutator = None
    elif args.search_for == "micro":
        model = MicroNetwork(num_layers=6, out_channels=20, num_nodes=5, dropout_rate=0.1, use_aux_heads=True)
        num_epochs = 150
        mutator = enas.EnasMutator(model, tanh_constant=1.1, cell_exit_extra_step=True)
    else:
        raise AssertionError

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), 0.05, momentum=0.9, weight_decay=1.0E-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=0.001)

    trainer = enas.EnasTrainer(model,
                               loss=criterion,
                               metrics=accuracy,
                               reward_function=reward_accuracy,
                               optimizer=optimizer,
54
                               callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")],
55
56
57
58
                               batch_size=args.batch_size,
                               num_epochs=num_epochs,
                               dataset_train=dataset_train,
                               dataset_valid=dataset_valid,
59
60
61
                               log_frequency=args.log_frequency,
                               mutator=mutator)
    trainer.train()