"""ProxylessNAS on ImageNet: architecture search (``search``, ``search_v1``) and retraining (``retrain``)."""
1
import logging
2
3
4
import os
import sys
from argparse import ArgumentParser
5

6
import torch
7
8
from nni.algorithms.nas.pytorch.proxylessnas import ProxylessNasTrainer
from torchvision import transforms
9

10
import datasets
11
from model import SearchMobileNet
from nni.algorithms.nas.pytorch.proxylessnas import ProxylessNasTrainer
13
from putils import LabelSmoothingLoss, accuracy, get_parameters
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from retrain import Retrain

logger = logging.getLogger('nni_proxylessnas')


def parse_int_list(text):
    """Parse a comma-separated string of integers, e.g. ``'2,2,1'`` -> ``[2, 2, 1]``.

    Raises:
        ValueError: if any comma-separated item is not an integer literal.
    """
    return [int(item) for item in text.split(',')]


def normalize_no_decay_keys(value):
    """Return ``None`` for the "disabled" spellings of ``--no_decay_keys``.

    argparse can never produce a real ``None`` from the command line, so the
    string ``'None'`` is accepted as the way to turn the no-weight-decay
    parameter group off (mirroring ``--distort_color None``). Any other value
    is returned unchanged.
    """
    return None if value in (None, 'None') else value


def build_arg_parser():
    """Create the command-line parser for the ProxylessNAS entry point.

    Returns:
        argparse.ArgumentParser: parser with model, dataset, mode, search and
        retrain options.
    """
    parser = ArgumentParser("proxylessnas")
    # configurations of the model
    parser.add_argument("--n_cell_stages", default='4,4,4,4,4,1', type=str)
    parser.add_argument("--stride_stages", default='2,2,2,1,2,1', type=str)
    parser.add_argument("--width_stages", default='24,40,80,96,192,320', type=str)
    parser.add_argument("--bn_momentum", default=0.1, type=float)
    parser.add_argument("--bn_eps", default=1e-3, type=float)
    parser.add_argument("--dropout_rate", default=0, type=float)
    # 'None' (a string, as for --distort_color) disables the no-decay group; a
    # Python None in `choices` could never be matched by command-line input.
    parser.add_argument("--no_decay_keys", default='bn', type=str, choices=['None', 'bn', 'bn#bias'])
    # configurations of imagenet dataset
    parser.add_argument("--data_path", default='/data/imagenet/', type=str)
    parser.add_argument("--train_batch_size", default=256, type=int)
    parser.add_argument("--test_batch_size", default=500, type=int)
    parser.add_argument("--n_worker", default=32, type=int)
    parser.add_argument("--resize_scale", default=0.08, type=float)
    parser.add_argument("--distort_color", default='normal', type=str, choices=['normal', 'strong', 'None'])
    # configurations for training mode
    parser.add_argument("--train_mode", default='search', type=str, choices=['search_v1', 'search', 'retrain'])
    # configurations for search
    parser.add_argument("--checkpoint_path", default='./search_mobile_net.pt', type=str)
    parser.add_argument("--arch_path", default='./arch_path.pt', type=str)
    parser.add_argument("--no-warmup", dest='warmup', action='store_false')
    # configurations for retrain
    parser.add_argument("--exported_arch_path", default=None, type=str)
    return parser


def build_optimizer(model, no_decay_keys, lr=0.05, momentum=0.9, nesterov=True, weight_decay=4e-5):
    """Create the SGD optimizer used by every training mode.

    Args:
        model: the network whose parameters are optimized.
        no_decay_keys: key pattern passed to ``get_parameters``. When truthy,
            parameters selected with ``mode='include'`` get no weight decay and
            those selected with ``mode='exclude'`` get ``weight_decay``. When
            falsy, ``weight_decay`` applies to all of ``get_parameters(model)``.
        lr, momentum, nesterov, weight_decay: SGD hyper-parameters.

    Returns:
        torch.optim.SGD
    """
    if no_decay_keys:
        param_groups = [
            {'params': get_parameters(model, no_decay_keys, mode='exclude'), 'weight_decay': weight_decay},
            {'params': get_parameters(model, no_decay_keys, mode='include'), 'weight_decay': 0},
        ]
        return torch.optim.SGD(param_groups, lr=lr, momentum=momentum, nesterov=nesterov)
    # momentum/nesterov are shared parameters, so this branch has them bound too.
    return torch.optim.SGD(get_parameters(model), lr=lr, momentum=momentum,
                           nesterov=nesterov, weight_decay=weight_decay)


def build_data_provider(args):
    """Create the ImageNet data provider used by the ``search_v1`` and ``retrain`` modes."""
    logger.info('Creating data provider...')
    data_provider = datasets.ImagenetDataProvider(save_path=args.data_path,
                                                  train_batch_size=args.train_batch_size,
                                                  test_batch_size=args.test_batch_size,
                                                  valid_size=None,
                                                  n_worker=args.n_worker,
                                                  resize_scale=args.resize_scale,
                                                  distort_color=args.distort_color)
    logger.info('Creating data provider done')
    return data_provider


def run_search(model, optimizer, args):
    """Search with the Retiarii ``ProxylessTrainer`` on torchvision ImageNet and print the result."""
    from nni.retiarii.trainer.pytorch import ProxylessTrainer
    from torchvision.datasets import ImageNet
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    dataset = ImageNet(args.data_path, transform=transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]))
    trainer = ProxylessTrainer(model,
                               loss=LabelSmoothingLoss(),
                               dataset=dataset,
                               optimizer=optimizer,
                               metrics=lambda output, target: accuracy(output, target, topk=(1, 5,)),
                               num_epochs=120,
                               log_frequency=10)
    trainer.fit()
    print('Final architecture:', trainer.export())


def run_search_v1(model, optimizer, device, data_provider, args):
    """Search with the legacy ``ProxylessNasTrainer`` and export the best architecture to ``args.arch_path``."""
    logger.info('Creating ProxylessNasTrainer...')
    trainer = ProxylessNasTrainer(model,
                                  model_optim=optimizer,
                                  train_loader=data_provider.train,
                                  valid_loader=data_provider.valid,
                                  device=device,
                                  warmup=args.warmup,
                                  ckpt_path=args.checkpoint_path,
                                  arch_path=args.arch_path)

    logger.info('Start to train with ProxylessNasTrainer...')
    trainer.train()
    logger.info('Training done')
    trainer.export(args.arch_path)
    logger.info('Best architecture exported in %s', args.arch_path)


def run_retrain(model, optimizer, device, data_provider, args):
    """Fix the supernet to the exported architecture in ``args.exported_arch_path`` and retrain it."""
    from nni.nas.pytorch.fixed import apply_fixed_architecture
    apply_fixed_architecture(model, args.exported_arch_path)
    trainer = Retrain(model, optimizer, device, data_provider, n_epochs=300)
    trainer.run()


def main():
    """Parse the CLI, build the SearchMobileNet supernet and its optimizer, then run the selected mode.

    Exits with status -1 when ``--train_mode retrain`` is used without an
    existing ``--exported_arch_path`` file.
    """
    args = build_arg_parser().parse_args()
    # Validate retrain inputs up front (explicitly, not with `assert`, which
    # is stripped under `python -O`) so we fail before building the model.
    if args.train_mode == 'retrain':
        if args.exported_arch_path is None:
            logger.error('When --train_mode is retrain, --exported_arch_path must be specified.')
            sys.exit(-1)
        if not os.path.isfile(args.exported_arch_path):
            logger.error('exported_arch_path %s should be a file.', args.exported_arch_path)
            sys.exit(-1)

    model = SearchMobileNet(width_stages=parse_int_list(args.width_stages),
                            n_cell_stages=parse_int_list(args.n_cell_stages),
                            stride_stages=parse_int_list(args.stride_stages),
                            n_classes=1000,
                            dropout_rate=args.dropout_rate,
                            bn_param=(args.bn_momentum, args.bn_eps))
    logger.info('SearchMobileNet model create done')
    model.init_model()
    logger.info('SearchMobileNet model init done')

    # Only selects the device handed to the search_v1 / retrain trainers;
    # the model is not moved here.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    optimizer = build_optimizer(model, normalize_no_decay_keys(args.no_decay_keys))

    # The data provider is only needed by search_v1 and retrain, so it is
    # built lazily instead of for every mode.
    if args.train_mode == 'search':
        run_search(model, optimizer, args)
    elif args.train_mode == 'search_v1':
        run_search_v1(model, optimizer, device, build_data_provider(args), args)
    elif args.train_mode == 'retrain':
        run_retrain(model, optimizer, device, build_data_provider(args), args)


if __name__ == "__main__":
    main()