# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import sys
import time
from argparse import ArgumentParser

import torch
import torch.nn as nn

from nni.nas.pytorch.callbacks import ArchitectureCheckpoint
from nni.nas.pytorch.pdarts import PdartsTrainer

# Wrapped in `if True:` so auto-formatters do not reorder these imports;
# the path must be on sys.path before importing the DARTS example modules.
if True:
    sys.path.append('../darts')
    from utils import accuracy
    from model import CNN
    import datasets


logger = logging.getLogger('nni')


if __name__ == "__main__":
    parser = ArgumentParser("pdarts")
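    # P-DARTS searches in progressive stages; the list-valued arguments below
    # are indexed by stage. Note that argparse appends command-line values to
    # these defaults rather than replacing them.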
    parser.add_argument('--add_layers', action='append', type=int,
                        default=[0, 6, 12],
                        help='number of layers to add at each stage')
    parser.add_argument('--dropped_ops', action='append', type=int,
                        default=[3, 2, 1],
                        help='number of candidate ops to drop at each stage')
    parser.add_argument("--nodes", default=4, type=int)
    parser.add_argument("--init_layers", default=5, type=int)
    parser.add_argument("--batch-size", default=64, type=int)
    parser.add_argument("--log-frequency", default=1, type=int)
    parser.add_argument("--epochs", default=50, type=int)
    parser.add_argument("--unrolled", default=False, action="store_true")
    args = parser.parse_args()

    logger.info("loading data")
    dataset_train, dataset_valid = datasets.get_dataset("cifar10")

    def model_creator(layers):
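        """Build a fresh model for one P-DARTS stage.

        PdartsTrainer calls this with the stage's total layer count and gets
        back the model along with its loss, optimizer and LR scheduler.
        """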
        model = CNN(32, 3, 16, 10, layers, n_nodes=args.nodes)
        criterion = nn.CrossEntropyLoss()

        optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, args.epochs, eta_min=0.001)

        return model, criterion, optim, lr_scheduler

    logger.info("initializing trainer")
    trainer = PdartsTrainer(model_creator,
                            init_layers=args.init_layers,
                            metrics=lambda output, target: accuracy(output, target, topk=(1,)),
                            pdarts_num_layers=args.add_layers,
                            pdarts_num_to_drop=args.dropped_ops,
                            num_epochs=args.epochs,
                            dataset_train=dataset_train,
                            dataset_valid=dataset_valid,
                            batch_size=args.batch_size,
                            log_frequency=args.log_frequency,
                            unrolled=args.unrolled,
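                            # save the searched architecture under ./checkpoints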
                            callbacks=[ArchitectureCheckpoint("./checkpoints")])
    logger.info("training")
    trainer.train()