search.py 1.47 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from argparse import ArgumentParser

import datasets
import torch
import torch.nn as nn

from model import SearchCNN
from nni.nas.pytorch.darts import DartsTrainer
from utils import accuracy


if __name__ == "__main__":
    # Command-line entry point: run a DARTS architecture search with
    # SearchCNN on the "cifar10" dataset, then export the result via
    # DartsTrainer.export().
    parser = ArgumentParser("darts")
    parser.add_argument("--layers", default=4, type=int)
    parser.add_argument("--nodes", default=2, type=int)
    parser.add_argument("--batch-size", default=128, type=int)
    parser.add_argument("--log-frequency", default=1, type=int)
    # Number of search epochs. This single value drives both the trainer and
    # the cosine LR schedule, so the two can never disagree (previously the
    # trainer used a separately hard-coded 50).
    parser.add_argument("--epochs", default=50, type=int)
    args = parser.parse_args()
    if args.epochs < 1:
        # CosineAnnealingLR needs a positive period (T_max).
        parser.error("--epochs must be a positive integer")

    dataset_train, dataset_valid = datasets.get_dataset("cifar10")

    # NOTE(review): positional args presumably mean
    # (input_channels=3, initial_channels=16, n_classes=10, n_layers);
    # confirm against model.SearchCNN.
    model = SearchCNN(3, 16, 10, args.layers, n_nodes=args.nodes)
    criterion = nn.CrossEntropyLoss()

    # SGD over model weights; the LR anneals from 0.025 down to 0.001
    # over n_epochs scheduler steps.
    optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
    n_epochs = args.epochs
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, n_epochs, eta_min=0.001)

    trainer = DartsTrainer(model,
                           loss=criterion,
                           # Report top-1 accuracy only.
                           metrics=lambda output, target: accuracy(output, target, topk=(1,)),
                           model_optim=optim,
                           lr_scheduler=lr_scheduler,
                           num_epochs=n_epochs,
                           dataset_train=dataset_train,
                           dataset_valid=dataset_valid,
                           batch_size=args.batch_size,
                           log_frequency=args.log_frequency)
    trainer.train()
    trainer.export()

# Augment step: not implemented in this file (placeholder).
# ...