# search.py: DARTS architecture search on CIFAR-10.
import logging
import time
from argparse import ArgumentParser

import torch
import torch.nn as nn

import datasets
from model import CNN
from nni.nas.pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallback
from nni.nas.pytorch.darts import DartsTrainer
from utils import accuracy

# Configure the root logger (getLogger() with no name): records of level INFO
# and above that propagate up to the root are printed through one handler.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Render %(asctime)s in local time. This is assigned on the Formatter class,
# so it affects every Formatter instance in the process, not just ours.
logging.Formatter.converter = time.localtime

fmt = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s'
formatter = logging.Formatter(fmt=fmt, datefmt='%m/%d/%Y, %I:%M:%S %p')

# NOTE: despite the name, StreamHandler() with no stream argument writes to
# sys.stderr, not stdout.
std_out_info = logging.StreamHandler()
std_out_info.setFormatter(formatter)
logger.addHandler(std_out_info)

def main(argv=None):
    """Run a DARTS architecture search on CIFAR-10.

    Builds the ``CNN`` super-network together with its loss, SGD optimizer
    and cosine learning-rate schedule, then hands everything to
    ``DartsTrainer`` and trains. Architecture checkpoints are written to
    ``./checkpoints`` through the ``ArchitectureCheckpoint`` callback.

    Args:
        argv: Command-line arguments to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, i.e. a normal script invocation.
    """
    parser = ArgumentParser("darts")
    parser.add_argument("--layers", default=8, type=int,
                        help="number of layers passed to CNN")
    parser.add_argument("--batch-size", default=64, type=int,
                        help="forwarded to DartsTrainer as batch_size")
    parser.add_argument("--log-frequency", default=10, type=int,
                        help="forwarded to DartsTrainer as log_frequency")
    parser.add_argument("--epochs", default=50, type=int,
                        help="DartsTrainer num_epochs (also the cosine schedule's T_max)")
    parser.add_argument("--unrolled", default=False, action="store_true",
                        help="forwarded to DartsTrainer as unrolled")
    args = parser.parse_args(argv)

    dataset_train, dataset_valid = datasets.get_dataset("cifar10")

    # NOTE(review): the positional CNN arguments presumably mean
    # (input_size=32, in_channels=3, channels=16, n_classes=10, n_layers);
    # confirm against model.CNN.
    model = CNN(32, 3, 16, 10, args.layers)
    criterion = nn.CrossEntropyLoss()

    # SGD over model.parameters(); the cosine schedule anneals the learning
    # rate from 0.025 towards eta_min=0.001 over T_max=args.epochs scheduler
    # steps (presumably one step per epoch, driven by LRSchedulerCallback).
    optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, args.epochs, eta_min=0.001)

    trainer = DartsTrainer(model,
                           loss=criterion,
                           # utils.accuracy restricted to topk=(1,)
                           metrics=lambda output, target: accuracy(output, target, topk=(1,)),
                           optimizer=optim,
                           num_epochs=args.epochs,
                           dataset_train=dataset_train,
                           dataset_valid=dataset_valid,
                           batch_size=args.batch_size,
                           log_frequency=args.log_frequency,
                           unrolled=args.unrolled,
                           callbacks=[LRSchedulerCallback(lr_scheduler),
                                      ArchitectureCheckpoint("./checkpoints")])
    trainer.train()


if __name__ == "__main__":
    main()