search.py 2.35 KB
Newer Older
Yuge Zhang's avatar
Yuge Zhang committed
1
2
3
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

Chi Song's avatar
Chi Song committed
4
5
import logging
import time
6
7
8
9
10
11
12
13
14
from argparse import ArgumentParser

import torch
import torch.nn as nn

import datasets
from macro import GeneralNetwork
from micro import MicroNetwork
from nni.nas.pytorch import enas
15
16
from nni.nas.pytorch.callbacks import (ArchitectureCheckpoint,
                                       LRSchedulerCallback)
17
18
from utils import accuracy, reward_accuracy

19
# Module-level logger, obtained from the logging hierarchy under the name "nni".
logger = logging.getLogger("nni")
Chi Song's avatar
Chi Song committed
20
21


22
23
24
if __name__ == "__main__":
    # ---- Command-line options -------------------------------------------
    arg_parser = ArgumentParser("enas")
    arg_parser.add_argument("--batch-size", default=128, type=int)
    arg_parser.add_argument("--log-frequency", default=10, type=int)
    arg_parser.add_argument("--search-for", choices=["macro", "micro"], default="macro")
    arg_parser.add_argument("--epochs", default=None, type=int, help="Number of epochs (default: macro 310, micro 150)")
    options = arg_parser.parse_args()

    # ---- Data -----------------------------------------------------------
    # get_dataset returns a (train, valid) pair for the dataset named "cifar10".
    dataset_train, dataset_valid = datasets.get_dataset("cifar10")

    # ---- Search space ---------------------------------------------------
    # Each search mode picks its own network, its fallback epoch count and
    # (for "micro" only) an explicitly configured mutator.
    search_target = options.search_for
    if search_target == "macro":
        model = GeneralNetwork()
        fallback_epochs = 310
        # No mutator is built for the macro space; None is handed to the
        # trainer as-is.
        # NOTE(review): presumably the trainer creates its own default
        # mutator in that case -- confirm against EnasTrainer.
        mutator = None
    elif search_target == "micro":
        model = MicroNetwork(num_layers=6, out_channels=20, num_nodes=5, dropout_rate=0.1, use_aux_heads=True)
        fallback_epochs = 150
        mutator = enas.EnasMutator(model, tanh_constant=1.1, cell_exit_extra_step=True)
    else:
        # argparse `choices` should already reject anything else; kept as a
        # defensive guard.
        raise AssertionError

    # `or` makes both an omitted --epochs (None) and an explicit 0 fall back
    # to the per-mode default.
    num_epochs = options.epochs or fallback_epochs

    # ---- Optimisation setup ---------------------------------------------
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=0.05,
        momentum=0.9,
        weight_decay=1.0E-4,
    )
    # Cosine annealing is configured over the full run (T_max == num_epochs)
    # with a minimum learning rate of 0.001.
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=num_epochs,
        eta_min=0.001,
    )

    # Callbacks handed to the trainer: one wrapping the LR scheduler, one
    # configured with the "./checkpoints" path.
    trainer_callbacks = [
        LRSchedulerCallback(lr_scheduler),
        ArchitectureCheckpoint("./checkpoints"),
    ]

    # ---- Run the search -------------------------------------------------
    trainer = enas.EnasTrainer(
        model,
        loss=criterion,
        metrics=accuracy,
        reward_function=reward_accuracy,
        optimizer=optimizer,
        callbacks=trainer_callbacks,
        batch_size=options.batch_size,
        num_epochs=num_epochs,
        dataset_train=dataset_train,
        dataset_valid=dataset_valid,
        log_frequency=options.log_frequency,
        mutator=mutator,
    )
    trainer.train()