# search.py -- DARTS architecture-search entry point (NNI example).
import logging
import time

from argparse import ArgumentParser

import torch
import torch.nn as nn

import datasets
from model import CNN
from nni.nas.pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallback
from nni.nas.pytorch.darts import DartsTrainer
from utils import accuracy

logger = logging.getLogger('nni')


if __name__ == "__main__":
    # Command-line options controlling the architecture search.
    parser = ArgumentParser("darts")
    parser.add_argument("--layers", default=8, type=int)
    parser.add_argument("--batch-size", default=64, type=int)
    parser.add_argument("--log-frequency", default=10, type=int)
    parser.add_argument("--epochs", default=50, type=int)
    # Flag: when given, DartsTrainer is asked for the "unrolled"
    # (second-order) architecture-gradient variant.
    parser.add_argument("--unrolled", default=False, action="store_true")
    args = parser.parse_args()

    # Train/validation splits from the project-local datasets helper.
    dataset_train, dataset_valid = datasets.get_dataset("cifar10")

    # Search model; positional args presumably are
    # (input_size=32, in_channels=3, base_channels=16, n_classes=10, n_layers)
    # for CIFAR-10 -- TODO confirm against model.CNN's signature.
    model = CNN(32, 3, 16, 10, args.layers)
    criterion = nn.CrossEntropyLoss()

    # Weight optimizer with cosine-annealed learning rate over all epochs.
    optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, args.epochs, eta_min=0.001)

    # The callbacks step the LR schedule and write the searched
    # architecture to ./checkpoints during training.
    trainer = DartsTrainer(model,
                           loss=criterion,
                           metrics=lambda output, target: accuracy(output, target, topk=(1,)),
                           optimizer=optim,
                           num_epochs=args.epochs,
                           dataset_train=dataset_train,
                           dataset_valid=dataset_valid,
                           batch_size=args.batch_size,
                           log_frequency=args.log_frequency,
                           unrolled=args.unrolled,
                           callbacks=[LRSchedulerCallback(lr_scheduler),
                                      ArchitectureCheckpoint("./checkpoints")])
    trainer.train()