main.py 5.26 KB
Newer Older
1
2
3
4
5
6
7
8
9
import os
import sys
import logging
from argparse import ArgumentParser
import torch
import datasets

from putils import get_parameters
from model import SearchMobileNet
colorjam's avatar
colorjam committed
10
from nni.algorithms.nas.pytorch.proxylessnas import ProxylessNasTrainer
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from retrain import Retrain

logger = logging.getLogger('nni_proxylessnas')

def build_arg_parser():
    """Build the command-line parser for the ProxylessNAS example script.

    Returns
    -------
    argparse.ArgumentParser
        Parser exposing the model, dataset, search and retrain options.
    """
    parser = ArgumentParser("proxylessnas")
    # configurations of the model
    parser.add_argument("--n_cell_stages", default='4,4,4,4,4,1', type=str)
    parser.add_argument("--stride_stages", default='2,2,2,1,2,1', type=str)
    parser.add_argument("--width_stages", default='24,40,80,96,192,320', type=str)
    parser.add_argument("--bn_momentum", default=0.1, type=float)
    parser.add_argument("--bn_eps", default=1e-3, type=float)
    parser.add_argument("--dropout_rate", default=0, type=float)
    parser.add_argument("--no_decay_keys", default='bn', type=str, choices=[None, 'bn', 'bn#bias'])
    # configurations of imagenet dataset
    parser.add_argument("--data_path", default='/data/imagenet/', type=str)
    parser.add_argument("--train_batch_size", default=256, type=int)
    parser.add_argument("--test_batch_size", default=500, type=int)
    parser.add_argument("--n_worker", default=32, type=int)
    parser.add_argument("--resize_scale", default=0.08, type=float)
    parser.add_argument("--distort_color", default='normal', type=str, choices=['normal', 'strong', 'None'])
    # configurations for training mode
    parser.add_argument("--train_mode", default='search', type=str, choices=['search', 'retrain'])
    # configurations for search
    parser.add_argument("--checkpoint_path", default='./search_mobile_net.pt', type=str)
    parser.add_argument("--arch_path", default='./arch_path.pt', type=str)
    # warmup is enabled unless --no-warmup is given (store_false defaults to True)
    parser.add_argument("--no-warmup", dest='warmup', action='store_false')
    # configurations for retrain
    parser.add_argument("--exported_arch_path", default=None, type=str)
    return parser


def build_optimizer(model, no_decay_keys):
    """Create the SGD optimizer used for both search and retrain.

    Parameters
    ----------
    model
        Network whose parameters are handed to ``get_parameters``.
    no_decay_keys
        Key string passed to ``get_parameters`` to select the parameters that
        get ``weight_decay=0``; when falsy, every parameter uses the same
        weight decay.

    Returns
    -------
    torch.optim.SGD
    """
    # Shared hyper-parameters live outside the branch: previously momentum and
    # nesterov were only bound in the ``if`` branch, so the fallback branch
    # raised NameError.
    lr, momentum, nesterov, weight_decay = 0.05, 0.9, True, 4e-5
    if no_decay_keys:
        param_groups = [
            {'params': get_parameters(model, no_decay_keys, mode='exclude'), 'weight_decay': weight_decay},
            {'params': get_parameters(model, no_decay_keys, mode='include'), 'weight_decay': 0},
        ]
        return torch.optim.SGD(param_groups, lr=lr, momentum=momentum, nesterov=nesterov)
    return torch.optim.SGD(get_parameters(model), lr=lr, momentum=momentum,
                           nesterov=nesterov, weight_decay=weight_decay)


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    if args.train_mode == 'retrain' and args.exported_arch_path is None:
        logger.error('When --train_mode is retrain, --exported_arch_path must be specified.')
        sys.exit(-1)

    model = SearchMobileNet(width_stages=[int(i) for i in args.width_stages.split(',')],
                            n_cell_stages=[int(i) for i in args.n_cell_stages.split(',')],
                            stride_stages=[int(i) for i in args.stride_stages.split(',')],
                            n_classes=1000,
                            dropout_rate=args.dropout_rate,
                            bn_param=(args.bn_momentum, args.bn_eps))
    logger.info('SearchMobileNet model create done')
    model.init_model()
    logger.info('SearchMobileNet model init done')

    # pick the GPU when one is available; the device is handed to the trainer below
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    logger.info('Creating data provider...')
    data_provider = datasets.ImagenetDataProvider(save_path=args.data_path,
                                                  train_batch_size=args.train_batch_size,
                                                  test_batch_size=args.test_batch_size,
                                                  valid_size=None,
                                                  n_worker=args.n_worker,
                                                  resize_scale=args.resize_scale,
                                                  distort_color=args.distort_color)
    logger.info('Creating data provider done')

    optimizer = build_optimizer(model, args.no_decay_keys)

    if args.train_mode == 'search':
        # this is architecture search
        logger.info('Creating ProxylessNasTrainer...')
        trainer = ProxylessNasTrainer(model,
                                      model_optim=optimizer,
                                      train_loader=data_provider.train,
                                      valid_loader=data_provider.valid,
                                      device=device,
                                      warmup=args.warmup,
                                      ckpt_path=args.checkpoint_path,
                                      arch_path=args.arch_path)

        logger.info('Start to train with ProxylessNasTrainer...')
        trainer.train()
        logger.info('Training done')
        trainer.export(args.arch_path)
        logger.info('Best architecture exported in %s', args.arch_path)
    elif args.train_mode == 'retrain':
        # this is retrain
        from nni.nas.pytorch.fixed import apply_fixed_architecture
        # explicit check instead of assert, so validation survives `python -O`
        if not os.path.isfile(args.exported_arch_path):
            logger.error("exported_arch_path %s should be a file.", args.exported_arch_path)
            sys.exit(-1)
        apply_fixed_architecture(model, args.exported_arch_path)
        trainer = Retrain(model, optimizer, device, data_provider, n_epochs=300)
        trainer.run()