# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""DARTS architecture search entry point.

Builds a CNN super-network over CIFAR-10 and runs DARTS one-shot search,
either with the legacy v1 trainer (``--v1``) or the Retiarii one-shot
trainer (default). The final exported architecture is printed and saved
to ``checkpoint.json``.
"""

import json
import logging
import time
from argparse import ArgumentParser

import torch
import torch.nn as nn

import datasets
from model import CNN
from nni.nas.pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallback
from utils import accuracy

logger = logging.getLogger('nni')


if __name__ == "__main__":
    parser = ArgumentParser("darts")
    parser.add_argument("--layers", default=8, type=int)
    parser.add_argument("--batch-size", default=64, type=int)
    parser.add_argument("--log-frequency", default=10, type=int)
    parser.add_argument("--epochs", default=50, type=int)
    parser.add_argument("--channels", default=16, type=int)
    parser.add_argument("--unrolled", default=False, action="store_true")
    parser.add_argument("--visualization", default=False, action="store_true")
    parser.add_argument("--v1", default=False, action="store_true")
    args = parser.parse_args()

    dataset_train, dataset_valid = datasets.get_dataset("cifar10")

    # 32x32 RGB input, 10 output classes (CIFAR-10).
    model = CNN(32, 3, args.channels, 10, args.layers)
    criterion = nn.CrossEntropyLoss()

    # Hyper-parameters follow the original DARTS paper: SGD with momentum
    # and a cosine-annealed learning rate decaying from 0.025 to 0.001.
    optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, args.epochs, eta_min=0.001)

    if args.v1:
        # Legacy (v1) NAS API: trainer drives scheduling and checkpointing
        # via callbacks and saves architectures under ./checkpoints.
        from nni.algorithms.nas.pytorch.darts import DartsTrainer
        trainer = DartsTrainer(model,
                               loss=criterion,
                               metrics=lambda output, target: accuracy(output, target, topk=(1,)),
                               optimizer=optim,
                               num_epochs=args.epochs,
                               dataset_train=dataset_train,
                               dataset_valid=dataset_valid,
                               batch_size=args.batch_size,
                               log_frequency=args.log_frequency,
                               unrolled=args.unrolled,
                               callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")])
        if args.visualization:
            trainer.enable_visualization()

        trainer.train()
    else:
        # Retiarii one-shot API: the trainer splits `dataset` internally,
        # and the searched architecture is exported explicitly afterwards.
        from nni.retiarii.oneshot.pytorch import DartsTrainer
        trainer = DartsTrainer(
            model=model,
            loss=criterion,
            metrics=lambda output, target: accuracy(output, target, topk=(1,)),
            optimizer=optim,
            num_epochs=args.epochs,
            dataset=dataset_train,
            batch_size=args.batch_size,
            log_frequency=args.log_frequency,
            unrolled=args.unrolled
        )
        trainer.fit()
        # Export once and reuse the result: repeated export() calls are
        # redundant work, and the original code left this variable unused.
        final_architecture = trainer.export()
        print('Final architecture:', final_architecture)
        # `with` ensures the checkpoint file is flushed and closed
        # (the original `json.dump(..., open(...))` leaked the handle).
        with open('checkpoint.json', 'w') as f:
            json.dump(final_architecture, f)