search.py 2.68 KB
Newer Older
Chi Song's avatar
Chi Song committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import sys
import time
from argparse import ArgumentParser

import torch
import torch.nn as nn

from nni.nas.pytorch.callbacks import ArchitectureCheckpoint
from nni.nas.pytorch.pdarts import PdartsTrainer

# Wrapped in `if True:` so import-sorting tools keep these imports after the sys.path tweak they depend on.
if True:
    sys.path.append('../darts')
    from utils import accuracy
    from model import CNN
    import datasets


logger = logging.getLogger('nni')


def parse_args(argv=None):
    """Parse the command line for the P-DARTS search.

    Args:
        argv: List of argument strings to parse; ``None`` means ``sys.argv[1:]``
            (standard ``ArgumentParser.parse_args`` behavior).

    Returns:
        ``argparse.Namespace``. ``add_layers`` and ``dropped_ops`` are always
        lists: ``[0, 6, 12]`` and ``[3, 2, 1]`` respectively when not given.

    Exits:
        Via ``parser.error`` (``SystemExit`` with status 2) when
        ``add_layers`` and ``dropped_ops`` have different lengths.
    """
    parser = ArgumentParser("pdarts")
    # default=None rather than a list default: with action='append', argparse
    # would append user-supplied values after the default entries instead of
    # replacing them. The real defaults are filled in after parsing.
    parser.add_argument('--add_layers', action='append', type=int,
                        help='add layers, default: [0, 6, 12]')
    parser.add_argument('--dropped_ops', action='append', type=int,
                        help='drop ops, default: [3, 2, 1]')
    parser.add_argument("--nodes", default=4, type=int)
    parser.add_argument("--init_layers", default=5, type=int)
    parser.add_argument("--channels", default=16, type=int)
    parser.add_argument("--batch-size", default=64, type=int)
    parser.add_argument("--log-frequency", default=1, type=int)
    parser.add_argument("--epochs", default=50, type=int)
    parser.add_argument("--unrolled", default=False, action="store_true")
    args = parser.parse_args(argv)

    if args.add_layers is None:
        args.add_layers = [0, 6, 12]
    if args.dropped_ops is None:
        args.dropped_ops = [3, 2, 1]

    # The two lists are paired entry-by-entry (see the parallel defaults).
    # NOTE(review): assumes PdartsTrainer consumes them in lockstep, one entry
    # per search stage. A length mismatch would otherwise fail or drop stages
    # only after earlier stages have already trained, so reject it up front.
    if len(args.add_layers) != len(args.dropped_ops):
        parser.error("--add_layers and --dropped_ops must have the same number "
                     "of entries (got %d and %d)"
                     % (len(args.add_layers), len(args.dropped_ops)))
    return args


def main(argv=None):
    """Run the P-DARTS architecture search on CIFAR-10.

    Args:
        argv: Optional argument list forwarded to ``parse_args``.
    """
    args = parse_args(argv)

    logger.info("loading data")
    dataset_train, dataset_valid = datasets.get_dataset("cifar10")

    def model_creator(layers):
        """Factory handed to PdartsTrainer.

        Builds a fresh model with ``layers`` layers together with its loss,
        SGD optimizer and cosine-annealing LR scheduler.
        """
        # NOTE(review): positional args presumably mean 32x32 inputs, 3 input
        # channels and 10 classes (CIFAR-10); confirm against darts/model.py.
        model = CNN(32, 3, args.channels, 10, layers, n_nodes=args.nodes)
        criterion = nn.CrossEntropyLoss()

        optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, args.epochs, eta_min=0.001)

        return model, criterion, optim, lr_scheduler

    logger.info("initializing trainer")
    trainer = PdartsTrainer(model_creator,
                            init_layers=args.init_layers,
                            metrics=lambda output, target: accuracy(output, target, topk=(1,)),
                            pdarts_num_layers=args.add_layers,
                            pdarts_num_to_drop=args.dropped_ops,
                            num_epochs=args.epochs,
                            dataset_train=dataset_train,
                            dataset_valid=dataset_valid,
                            batch_size=args.batch_size,
                            log_frequency=args.log_frequency,
                            unrolled=args.unrolled,
                            callbacks=[ArchitectureCheckpoint("./checkpoints")])
    logger.info("training")
    trainer.train()


if __name__ == "__main__":
    main()