"mmdet3d/structures/ops/iou3d_calculator.py" did not exist on "ba492be7ea3aa5dbae420a190377496127e767b9"
search.py 2.69 KB
Newer Older
Yuge Zhang's avatar
Yuge Zhang committed
1
2
3
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

Chi Song's avatar
Chi Song committed
4
5
import logging
import time
6
7
8
9
10
from argparse import ArgumentParser

import torch
import torch.nn as nn

Chi Song's avatar
Chi Song committed
11
import datasets
12
from model import CNN
13
from nni.nas.pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallback
14
15
from utils import accuracy

16

17
# Script-wide logger, bound to the logger registered under the name "nni".
logger = logging.getLogger(name="nni")
18

19
20
if __name__ == "__main__":
    parser = ArgumentParser("darts")
21
    parser.add_argument("--layers", default=8, type=int)
22
    parser.add_argument("--batch-size", default=64, type=int)
23
24
    parser.add_argument("--log-frequency", default=10, type=int)
    parser.add_argument("--epochs", default=50, type=int)
Yuge Zhang's avatar
Yuge Zhang committed
25
    parser.add_argument("--channels", default=16, type=int)
26
    parser.add_argument("--unrolled", default=False, action="store_true")
27
    parser.add_argument("--visualization", default=False, action="store_true")
28
    parser.add_argument("--v1", default=False, action="store_true")
29
30
31
32
    args = parser.parse_args()

    dataset_train, dataset_valid = datasets.get_dataset("cifar10")

Yuge Zhang's avatar
Yuge Zhang committed
33
    model = CNN(32, 3, args.channels, 10, args.layers)
34
35
36
    criterion = nn.CrossEntropyLoss()

    optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
37
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, args.epochs, eta_min=0.001)
38

39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
    if args.v1:
        from nni.algorithms.nas.pytorch.darts import DartsTrainer
        trainer = DartsTrainer(model,
                               loss=criterion,
                               metrics=lambda output, target: accuracy(output, target, topk=(1,)),
                               optimizer=optim,
                               num_epochs=args.epochs,
                               dataset_train=dataset_train,
                               dataset_valid=dataset_valid,
                               batch_size=args.batch_size,
                               log_frequency=args.log_frequency,
                               unrolled=args.unrolled,
                               callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")])
        if args.visualization:
            trainer.enable_visualization()

        trainer.train()
    else:
        from nni.retiarii.trainer.pytorch import DartsTrainer
        trainer = DartsTrainer(
            model=model,
            loss=criterion,
            metrics=lambda output, target: accuracy(output, target, topk=(1,)),
            optimizer=optim,
            num_epochs=args.epochs,
            dataset=dataset_train,
            batch_size=args.batch_size,
            log_frequency=args.log_frequency,
            unrolled=args.unrolled
        )
        trainer.fit()
        print('Final architecture:', trainer.export())