search.py 1.86 KB
Newer Older
Yuge Zhang's avatar
Yuge Zhang committed
1
2
3
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

Chi Song's avatar
Chi Song committed
4
5
import logging
import time
6
7
8
9
10
from argparse import ArgumentParser

import torch
import torch.nn as nn

Chi Song's avatar
Chi Song committed
11
import datasets
12
from model import CNN
13
from nni.nas.pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallback
14
from nni.nas.pytorch.darts import DartsTrainer
15
16
from utils import accuracy

17
logger = logging.getLogger('nni')


def build_parser():
    """Create the command-line parser for the DARTS search script.

    Returns:
        ArgumentParser: a parser named ``"darts"`` exposing ``--layers``,
        ``--batch-size``, ``--log-frequency``, ``--epochs``, ``--channels``
        (all integers) and the boolean flag ``--unrolled``.
    """
    parser = ArgumentParser("darts")
    parser.add_argument("--layers", default=8, type=int)
    parser.add_argument("--batch-size", default=64, type=int)
    parser.add_argument("--log-frequency", default=10, type=int)
    parser.add_argument("--epochs", default=50, type=int)
    parser.add_argument("--channels", default=16, type=int)
    parser.add_argument("--unrolled", default=False, action="store_true")
    return parser


def main():
    """Parse the command line and run a DARTS search on the "cifar10" dataset.

    Builds the model, loss, SGD optimizer and cosine-annealing learning-rate
    schedule, hands them to ``DartsTrainer`` and calls ``train()`` on it.
    """
    args = build_parser().parse_args()

    dataset_train, dataset_valid = datasets.get_dataset("cifar10")

    # NOTE(review): the literals 32, 3 and 10 presumably describe the input
    # size, input channels and number of classes -- confirm against model.CNN.
    model = CNN(32, 3, args.channels, 10, args.layers)
    criterion = nn.CrossEntropyLoss()

    # SGD with an initial learning rate of 0.025; the scheduler below anneals
    # it over ``args.epochs`` steps down to a floor of 0.001.
    optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, args.epochs, eta_min=0.001)

    trainer = DartsTrainer(
        model,
        loss=criterion,
        metrics=lambda output, target: accuracy(output, target, topk=(1,)),
        optimizer=optim,
        num_epochs=args.epochs,
        dataset_train=dataset_train,
        dataset_valid=dataset_valid,
        batch_size=args.batch_size,
        log_frequency=args.log_frequency,
        unrolled=args.unrolled,
        # NOTE(review): "./checkpoints" is presumably the directory the
        # architecture checkpoint callback writes to -- confirm in the NNI docs.
        callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")],
    )
    trainer.train()


if __name__ == "__main__":
    main()