# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""DARTS architecture search entry point.

Searches a CNN cell architecture on CIFAR-10 using either the legacy
(v1, ``nni.algorithms.nas``) or the Retiarii (``nni.retiarii``) DartsTrainer,
selected with ``--v1``.
"""

import json
import logging
import time
from argparse import ArgumentParser

import torch
import torch.nn as nn

import datasets
from model import CNN
from nni.nas.pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallback
from utils import accuracy

logger = logging.getLogger('nni')


if __name__ == "__main__":
    parser = ArgumentParser("darts")
    parser.add_argument("--layers", default=8, type=int)
    parser.add_argument("--batch-size", default=64, type=int)
    parser.add_argument("--log-frequency", default=10, type=int)
    parser.add_argument("--epochs", default=50, type=int)
    parser.add_argument("--channels", default=16, type=int)
    parser.add_argument("--unrolled", default=False, action="store_true")
    parser.add_argument("--visualization", default=False, action="store_true")
    parser.add_argument("--v1", default=False, action="store_true")
    args = parser.parse_args()

    dataset_train, dataset_valid = datasets.get_dataset("cifar10")

    # CNN(input_size=32, in_channels=3 (RGB), channels, n_classes=10, n_layers)
    model = CNN(32, 3, args.channels, 10, args.layers)
    criterion = nn.CrossEntropyLoss()

    # Standard DARTS hyperparameters: SGD with momentum + cosine annealing.
    optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, args.epochs, eta_min=0.001)

    if args.v1:
        # Legacy one-shot trainer: drives training itself via callbacks and
        # checkpoints the architecture each epoch.
        from nni.algorithms.nas.pytorch.darts import DartsTrainer
        trainer = DartsTrainer(model,
                               loss=criterion,
                               metrics=lambda output, target: accuracy(output, target, topk=(1,)),
                               optimizer=optim,
                               num_epochs=args.epochs,
                               dataset_train=dataset_train,
                               dataset_valid=dataset_valid,
                               batch_size=args.batch_size,
                               log_frequency=args.log_frequency,
                               unrolled=args.unrolled,
                               callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")])
        if args.visualization:
            trainer.enable_visualization()

        trainer.train()
    else:
        # Retiarii trainer: splits `dataset` internally into train/valid halves,
        # so only the training set is passed. NOTE(review): lr_scheduler is not
        # wired into this branch — presumably matching upstream behavior; confirm.
        from nni.retiarii.trainer.pytorch import DartsTrainer
        trainer = DartsTrainer(
            model=model,
            loss=criterion,
            metrics=lambda output, target: accuracy(output, target, topk=(1,)),
            optimizer=optim,
            num_epochs=args.epochs,
            dataset=dataset_train,
            batch_size=args.batch_size,
            log_frequency=args.log_frequency,
            unrolled=args.unrolled
        )
        trainer.fit()
        # Export once and reuse: the original called trainer.export() three
        # times, recomputing the architecture decision each call.
        final_architecture = trainer.export()
        print('Final architecture:', final_architecture)
        # Context manager closes the file (original leaked the handle).
        with open('checkpoint.json', 'w') as f:
            json.dump(final_architecture, f)