# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import time
from argparse import ArgumentParser

import torch
import torch.nn as nn

import datasets
from macro import GeneralNetwork
from micro import MicroNetwork
from nni.algorithms.nas.pytorch import enas
from nni.nas.pytorch.callbacks import (ArchitectureCheckpoint,
                                       LRSchedulerCallback)
from utils import accuracy, reward_accuracy

logger = logging.getLogger('nni')
if __name__ == "__main__":
    # Entry point: run ENAS architecture search on CIFAR-10 over either the
    # macro or the micro search space, with either the legacy (v1) NAS API or
    # the Retiarii trainer.
    parser = ArgumentParser("enas")
    parser.add_argument("--batch-size", default=128, type=int)
    parser.add_argument("--log-frequency", default=10, type=int)
    parser.add_argument("--search-for", choices=["macro", "micro"], default="macro")
    parser.add_argument("--epochs", default=None, type=int,
                        help="Number of epochs (default: macro 310, micro 150)")
    parser.add_argument("--visualization", default=False, action="store_true")
    parser.add_argument("--v1", default=False, action="store_true")
    args = parser.parse_args()

    dataset_train, dataset_valid = datasets.get_dataset("cifar10")

    # `mutator` is consumed only by the legacy (v1) trainer; `ctrl_kwargs`
    # only by the Retiarii trainer. The defaults cover the macro space.
    mutator = None
    ctrl_kwargs = {}
    if args.search_for == "macro":
        model = GeneralNetwork()
        num_epochs = args.epochs or 310
    elif args.search_for == "micro":
        model = MicroNetwork(num_layers=6, out_channels=20, num_nodes=5,
                             dropout_rate=0.1, use_aux_heads=False)
        num_epochs = args.epochs or 150
        # The micro space needs a specially-configured controller.
        if args.v1:
            mutator = enas.EnasMutator(model, tanh_constant=1.1, cell_exit_extra_step=True)
        else:
            ctrl_kwargs = {"tanh_constant": 1.1}
    else:
        # Unreachable: argparse `choices` already restricts --search-for.
        raise AssertionError("Unsupported search space: %r" % args.search_for)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), 0.05, momentum=0.9, weight_decay=1.0E-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=0.001)

    if args.v1:
        # Legacy NAS API: LR scheduling and checkpointing go through callbacks.
        trainer = enas.EnasTrainer(model,
                                   loss=criterion,
                                   metrics=accuracy,
                                   reward_function=reward_accuracy,
                                   optimizer=optimizer,
                                   callbacks=[LRSchedulerCallback(lr_scheduler),
                                              ArchitectureCheckpoint("./checkpoints")],
                                   batch_size=args.batch_size,
                                   num_epochs=num_epochs,
                                   dataset_train=dataset_train,
                                   dataset_valid=dataset_valid,
                                   log_frequency=args.log_frequency,
                                   mutator=mutator)
        if args.visualization:
            trainer.enable_visualization()
        trainer.train()
    else:
        # Retiarii trainer; imported lazily so the v1 path does not need it.
        from nni.retiarii.trainer.pytorch.enas import EnasTrainer
        trainer = EnasTrainer(model,
                              loss=criterion,
                              metrics=accuracy,
                              reward_function=reward_accuracy,
                              optimizer=optimizer,
                              batch_size=args.batch_size,
                              num_epochs=num_epochs,
                              dataset=dataset_train,
                              log_frequency=args.log_frequency,
                              ctrl_kwargs=ctrl_kwargs)
        trainer.fit()