"examples/community/img2img_inpainting.py" did not exist on "8aac1f99d7af5873db7d23c07fba370d0f5061a6"
search.py 1.56 KB
Newer Older
1
2
from argparse import ArgumentParser

3
import datasets
4
5
6
import torch
import torch.nn as nn

7
8
9
from model import CNN
from nni.nas.pytorch.callbacks import LearningRateScheduler, ArchitectureCheckpoint
from nni.nas.pytorch.darts import DartsTrainer
10
11
from utils import accuracy

12

13
14
if __name__ == "__main__":
    parser = ArgumentParser("darts")
15
16
17
18
    parser.add_argument("--layers", default=8, type=int)
    parser.add_argument("--batch-size", default=96, type=int)
    parser.add_argument("--log-frequency", default=10, type=int)
    parser.add_argument("--epochs", default=50, type=int)
19
20
21
22
    args = parser.parse_args()

    dataset_train, dataset_valid = datasets.get_dataset("cifar10")

23
    model = CNN(32, 3, 16, 10, args.layers)
24
25
26
    criterion = nn.CrossEntropyLoss()

    optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
27
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, args.epochs, eta_min=0.001)
28
29
30
31

    trainer = DartsTrainer(model,
                           loss=criterion,
                           metrics=lambda output, target: accuracy(output, target, topk=(1,)),
32
33
                           optimizer=optim,
                           num_epochs=args.epochs,
34
35
36
                           dataset_train=dataset_train,
                           dataset_valid=dataset_valid,
                           batch_size=args.batch_size,
37
38
39
                           log_frequency=args.log_frequency,
                           callbacks=[LearningRateScheduler(lr_scheduler), ArchitectureCheckpoint("./checkpoints")])
    trainer.train_and_validate()