# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import sys
import time
from argparse import ArgumentParser

import torch
import torch.nn as nn

from nni.nas.pytorch.callbacks import ArchitectureCheckpoint
from nni.algorithms.nas.pytorch.pdarts import PdartsTrainer

# Keep these imports inside an ``if`` block so that import sorters (isort and
# similar) cannot hoist them above the ``sys.path.append`` call they depend on.
if True:
    # Make the sibling DARTS example importable: ``utils``, ``model`` and
    # ``datasets`` are looked up in ../darts. A relative sys.path entry is
    # resolved against the current working directory, so this presumably
    # expects the script to be run from its own directory.
    sys.path.append('../darts')
    from utils import accuracy
    from model import CNN
    import datasets


# Script logger. It uses the 'nni' logger name rather than __name__,
# presumably so output goes through NNI's own logging setup -- confirm.
logger = logging.getLogger('nni')


def parse_args(argv=None):
    """Parse the command-line options of the P-DARTS search script.

    Args:
        argv: list of argument strings to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, exactly as the script
            always did.

    Returns:
        argparse.Namespace with the parsed options. ``add_layers`` defaults
        to ``[0, 6, 12]`` and ``dropped_ops`` to ``[3, 2, 1]`` when the
        corresponding flag is never given.

    Raises:
        SystemExit: raised by ``parser.error`` when ``add_layers`` and
            ``dropped_ops`` end up with different lengths.
    """
    parser = ArgumentParser("pdarts")
    # Repeat a flag to build a list, e.g. ``--add_layers 0 --add_layers 6``.
    # The list defaults are filled in after parsing because argparse's
    # 'append' action would add command-line values *onto* a non-empty
    # default list instead of replacing it.
    parser.add_argument('--add_layers', action='append', type=int,
                        help='add layers, default: [0, 6, 12]')
    parser.add_argument('--dropped_ops', action='append', type=int,
                        help='drop ops, default: [3, 2, 1]')
    parser.add_argument("--nodes", default=4, type=int)
    parser.add_argument("--init_layers", default=5, type=int)
    parser.add_argument("--channels", default=16, type=int)
    parser.add_argument("--batch-size", default=64, type=int)
    parser.add_argument("--log-frequency", default=1, type=int)
    parser.add_argument("--epochs", default=50, type=int)
    parser.add_argument("--unrolled", default=False, action="store_true")
    args = parser.parse_args(argv)

    if args.add_layers is None:
        args.add_layers = [0, 6, 12]
    if args.dropped_ops is None:
        args.dropped_ops = [3, 2, 1]
    # Both lists describe the same sequence of search stages (note the
    # matching 3-element defaults). NOTE(review): assumes PdartsTrainer
    # reads one entry of each per stage. A mismatch is rejected here
    # instead of being left to surface partway through training.
    if len(args.add_layers) != len(args.dropped_ops):
        parser.error("--add_layers and --dropped_ops must have the same number of "
                     "entries (got {} and {})".format(len(args.add_layers),
                                                      len(args.dropped_ops)))
    return args


def main():
    """Run the P-DARTS architecture search on CIFAR-10.

    Reads options from the command line (see ``parse_args``), loads the
    CIFAR-10 train/valid datasets and builds a ``PdartsTrainer``. It then
    trains with an ``ArchitectureCheckpoint`` callback that writes to
    ``./checkpoints``.
    """
    args = parse_args()

    logger.info("loading data")
    dataset_train, dataset_valid = datasets.get_dataset("cifar10")

    def model_creator(layers):
        """Build a fresh search network together with its training objects.

        Args:
            layers: number of cells for the network. Presumably supplied by
                PdartsTrainer for each stage (``init_layers`` plus that
                stage's ``add_layers`` entry) -- confirm against the trainer.

        Returns:
            ``(model, criterion, optimizer, lr_scheduler)`` tuple.
        """
        # NOTE(review): 32 / 3 / 10 look like CIFAR-10's input size, input
        # channels and class count -- confirm against darts/model.CNN.
        model = CNN(32, 3, args.channels, 10, layers, n_nodes=args.nodes)
        criterion = nn.CrossEntropyLoss()

        optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
        # Cosine decay from lr 0.025 towards eta_min=0.001 with
        # T_max = args.epochs scheduler steps.
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, args.epochs, eta_min=0.001)

        return model, criterion, optim, lr_scheduler

    logger.info("initializing trainer")
    trainer = PdartsTrainer(model_creator,
                            init_layers=args.init_layers,
                            # Report top-1 accuracy only.
                            metrics=lambda output, target: accuracy(output, target, topk=(1,)),
                            pdarts_num_layers=args.add_layers,
                            pdarts_num_to_drop=args.dropped_ops,
                            num_epochs=args.epochs,
                            dataset_train=dataset_train,
                            dataset_valid=dataset_valid,
                            batch_size=args.batch_size,
                            log_frequency=args.log_frequency,
                            unrolled=args.unrolled,
                            callbacks=[ArchitectureCheckpoint("./checkpoints")])
    logger.info("training")
    trainer.train()


if __name__ == "__main__":
    main()