main.py 2.39 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from argparse import ArgumentParser

import datasets
import torch
import torch.nn as nn
import nni.nas.pytorch as nas
from nni.nas.pytorch.pdarts import PdartsTrainer
from nni.nas.pytorch.darts import CnnNetwork, CnnCell


def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy for each k in ``topk``.

    Args:
        output: 2-D tensor of class scores; classes are ranked along dim 1.
        target: Either a 1-D tensor of class indices, or a 2-D tensor (e.g.
            one-hot) whose argmax along dim 1 is taken as the label.
        topk: Iterable of k values to evaluate.

    Returns:
        dict mapping ``"acc{k}"`` to the fraction of samples whose target
        appears among the k highest-scoring classes, as a Python float.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    # After transposing, row i holds every sample's i-th best class.
    pred = pred.t()
    # One-hot (or other 2-D) targets: reduce to class indices.
    if target.ndimension() > 1:
        target = target.max(1)[1]

    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = dict()
    for k in topk:
        # `correct` derives from a transposed tensor and may be
        # non-contiguous; reshape (unlike view) copies when necessary
        # instead of raising for k > 1.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
    return res


def parse_args(argv=None):
    """Parse command-line options for the P-DARTS search script.

    Args:
        argv: Argument list to parse; ``None`` parses ``sys.argv[1:]``.

    Returns:
        argparse.Namespace. ``add_layers`` defaults to ``[0, 6, 12]`` and
        ``num_to_drop`` to ``[3, 2, 2]`` when the corresponding flags are not
        given; otherwise they hold exactly the (int) values passed.

    Raises:
        SystemExit: via ``parser.error`` when the two per-stage lists differ
            in length.
    """
    parser = ArgumentParser("darts")
    parser.add_argument("--layers", default=5, type=int)
    # `action="append"` with a non-empty default would append user values
    # *after* the defaults (and, without `type`, as strings). Use a None
    # default and fill it in after parsing so the flag really overrides it.
    parser.add_argument("--add-layers", "--add_layers", dest="add_layers",
                        action="append", type=int, default=None,
                        help="layers added at each stage; repeat once per "
                             "stage (default: 0 6 12)")
    parser.add_argument("--num-to-drop", dest="num_to_drop",
                        action="append", type=int, default=None,
                        help="candidate ops dropped at each stage; repeat "
                             "once per stage (default: 3 2 2)")
    parser.add_argument("--nodes", default=4, type=int)
    parser.add_argument("--batch-size", default=128, type=int)
    parser.add_argument("--epochs", default=50, type=int)
    parser.add_argument("--log-frequency", default=1, type=int)
    args = parser.parse_args(argv)

    if args.add_layers is None:
        args.add_layers = [0, 6, 12]
    if args.num_to_drop is None:
        args.num_to_drop = [3, 2, 2]
    # Both lists describe the same sequence of search stages.
    if len(args.add_layers) != len(args.num_to_drop):
        parser.error(
            "--add-layers and --num-to-drop must describe the same number "
            "of stages (got {} and {})".format(
                len(args.add_layers), len(args.num_to_drop)))
    return args


if __name__ == "__main__":
    args = parse_args()

    dataset_train, dataset_valid = datasets.get_dataset("cifar10")

    def model_creator(layers, n_nodes):
        """Build ``(model, loss, optimizer, lr_scheduler)`` for one stage.

        NOTE(review): presumably invoked by PdartsTrainer once per stage with
        the stage's layer count; confirm against the NNI trainer.
        """
        # NOTE(review): positional args look like (in_channels=3,
        # channels=16, n_classes=10) for CIFAR-10 — confirm with CnnNetwork.
        model = CnnNetwork(3, 16, 10, layers, n_nodes=n_nodes, cell_type=CnnCell)
        loss = nn.CrossEntropyLoss()

        model_optim = torch.optim.SGD(model.parameters(), 0.025,
                                      momentum=0.9, weight_decay=3.0E-4)
        # Anneal over the same number of epochs the trainer is told to run.
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            model_optim, args.epochs, eta_min=0.001)
        return model, loss, model_optim, lr_scheduler

    trainer = PdartsTrainer(model_creator,
                            metrics=lambda output, target: accuracy(output, target, topk=(1,)),
                            num_epochs=args.epochs,
                            pdarts_num_layers=args.add_layers,
                            pdarts_num_to_drop=args.num_to_drop,
                            dataset_train=dataset_train,
                            dataset_valid=dataset_valid,
                            layers=args.layers,
                            n_nodes=args.nodes,
                            batch_size=args.batch_size,
                            log_frequency=args.log_frequency)
    trainer.train()
    trainer.export()