"""ProxylessNAS on ImageNet: search an architecture (``--train_mode search``) or
retrain an exported one (``--train_mode retrain --exported_arch_path <file>``)."""
import json
import logging
import os
import sys
from argparse import ArgumentParser

import torch
from torchvision import transforms
from nni.retiarii.fixed import fixed_arch

import datasets
from model import SearchMobileNet
from putils import LabelSmoothingLoss, accuracy, get_parameters
from retrain import Retrain

logger = logging.getLogger('nni_proxylessnas')


def _none_if_literal_none(value):
    """argparse ``type`` hook: map the command-line literal ``'None'`` to ``None``.

    argparse validates ``choices`` against the *converted* value. With a plain
    ``type=str`` the string ``'None'`` could never match the ``None`` choice of
    ``--no_decay_keys``, so disabling the no-decay split was impossible.
    """
    return None if value == 'None' else value


def parse_int_list(spec):
    """Parse a comma-separated string such as ``'4,4,1'`` into ``[4, 4, 1]``."""
    return [int(item) for item in spec.split(',')]


def build_parser():
    """Return the command-line parser shared by search and retrain modes."""
    parser = ArgumentParser("proxylessnas")
    # configurations of the model
    parser.add_argument("--n_cell_stages", default='4,4,4,4,4,1', type=str)
    parser.add_argument("--stride_stages", default='2,2,2,1,2,1', type=str)
    parser.add_argument("--width_stages", default='24,40,80,96,192,320', type=str)
    parser.add_argument("--bn_momentum", default=0.1, type=float)
    parser.add_argument("--bn_eps", default=1e-3, type=float)
    parser.add_argument("--dropout_rate", default=0, type=float)
    parser.add_argument("--no_decay_keys", default='bn', type=_none_if_literal_none,
                        choices=[None, 'bn', 'bn#bias'],
                        help="parameter keys exempt from weight decay; pass 'None' to decay all parameters")
    parser.add_argument('--grad_reg_loss_type', default='add#linear', type=str, choices=['add#linear', 'mul#log'])
    parser.add_argument('--grad_reg_loss_lambda', default=1e-1, type=float)  # grad_reg_loss_params
    parser.add_argument('--grad_reg_loss_alpha', default=0.2, type=float)  # grad_reg_loss_params
    parser.add_argument('--grad_reg_loss_beta', default=0.3, type=float)  # grad_reg_loss_params
    parser.add_argument("--applied_hardware", default=None, type=str, help='the hardware to predict model latency')
    parser.add_argument("--reference_latency", default=None, type=float, help='the reference latency in specified hardware')
    # configurations of imagenet dataset
    parser.add_argument("--data_path", default='/data/imagenet/', type=str)
    parser.add_argument("--train_batch_size", default=256, type=int)
    parser.add_argument("--test_batch_size", default=500, type=int)
    parser.add_argument("--n_worker", default=32, type=int)
    parser.add_argument("--resize_scale", default=0.08, type=float)
    parser.add_argument("--distort_color", default='normal', type=str, choices=['normal', 'strong', 'None'])
    # configurations for training mode
    parser.add_argument("--train_mode", default='search', type=str, choices=['search', 'retrain'])
    # configurations for search
    # NOTE(review): --checkpoint_path, --arch_path and --no-warmup are parsed but
    # never read by this script; kept so existing command lines keep working.
    parser.add_argument("--checkpoint_path", default='./search_mobile_net.pt', type=str)
    parser.add_argument("--arch_path", default='./arch_path.pt', type=str)
    parser.add_argument("--no-warmup", dest='warmup', action='store_false')
    # configurations for retrain
    parser.add_argument("--exported_arch_path", default=None, type=str)
    return parser


def get_grad_reg_loss_params(args):
    """Collect the ``grad_reg_loss_params`` passed to ``ProxylessTrainer``.

    ``add#linear`` uses ``{'lambda': ...}``; ``mul#log`` uses
    ``{'alpha': ..., 'beta': ...}``. Any other type yields ``None``. The
    parser's ``choices`` currently rule that case out.
    """
    if args.grad_reg_loss_type == 'add#linear':
        return {'lambda': args.grad_reg_loss_lambda}
    if args.grad_reg_loss_type == 'mul#log':
        return {
            'alpha': args.grad_reg_loss_alpha,
            'beta': args.grad_reg_loss_beta,
        }
    return None


def build_model(args):
    """Instantiate ``SearchMobileNet`` (1000 classes) from the stage/BN CLI options."""
    return SearchMobileNet(width_stages=parse_int_list(args.width_stages),
                           n_cell_stages=parse_int_list(args.n_cell_stages),
                           stride_stages=parse_int_list(args.stride_stages),
                           n_classes=1000,
                           dropout_rate=args.dropout_rate,
                           bn_param=(args.bn_momentum, args.bn_eps))


def build_optimizer(model, no_decay_keys):
    """Create the SGD optimizer.

    Settings are lr 0.05, Nesterov momentum 0.9 and weight decay 4e-5.

    When ``no_decay_keys`` is set, the parameters returned by
    ``get_parameters(model, keys, mode='include')`` get weight decay 0 and the
    rest keep 4e-5. ``get_parameters`` is presumably name-based matching on keys
    like ``'bn'``; see putils. When it is ``None``, every parameter is decayed.
    """
    momentum, nesterov = 0.9, True
    if no_decay_keys:
        param_groups = [
            {'params': get_parameters(model, no_decay_keys, mode='exclude'), 'weight_decay': 4e-5},
            {'params': get_parameters(model, no_decay_keys, mode='include'), 'weight_decay': 0},
        ]
        return torch.optim.SGD(param_groups, lr=0.05, momentum=momentum, nesterov=nesterov)
    return torch.optim.SGD(get_parameters(model), lr=0.05, momentum=momentum, nesterov=nesterov,
                           weight_decay=4e-5)


def run_search(args, model, optimizer):
    """Run ProxylessNAS search on the ImageNet train split.

    Prints the exported architecture and writes it as JSON to ``checkpoint.json``.
    """
    from nni.retiarii.oneshot.pytorch import ProxylessTrainer
    from torchvision.datasets import ImageNet

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    dataset = ImageNet(args.data_path, transform=transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]))
    trainer = ProxylessTrainer(model,
                               loss=LabelSmoothingLoss(),
                               dataset=dataset,
                               optimizer=optimizer,
                               metrics=lambda output, target: accuracy(output, target, topk=(1, 5,)),
                               num_epochs=120,
                               log_frequency=10,
                               grad_reg_loss_type=args.grad_reg_loss_type,
                               grad_reg_loss_params=get_grad_reg_loss_params(args),
                               applied_hardware=args.applied_hardware, dummy_input=(1, 3, 224, 224),
                               ref_latency=args.reference_latency)
    trainer.fit()
    # Export once, so the printed and the saved architecture are the same object.
    exported_arch = trainer.export()
    print('Final architecture:', exported_arch)
    with open('checkpoint.json', 'w') as f:
        json.dump(exported_arch, f)


def run_retrain(args, model, optimizer):
    """Retrain a fixed architecture for 300 epochs with the ImageNet data provider.

    Uses CUDA when available, otherwise CPU. The data provider is only built
    here, because search mode loads ImageNet through torchvision instead.
    """
    # Pick the device handed to Retrain; the model itself is not moved here.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    logger.info('Creating data provider...')
    data_provider = datasets.ImagenetDataProvider(save_path=args.data_path,
                                                  train_batch_size=args.train_batch_size,
                                                  test_batch_size=args.test_batch_size,
                                                  valid_size=None,
                                                  n_worker=args.n_worker,
                                                  resize_scale=args.resize_scale,
                                                  distort_color=args.distort_color)
    logger.info('Creating data provider done')

    trainer = Retrain(model, optimizer, device, data_provider, n_epochs=300)
    trainer.run()


def main(argv=None):
    """Parse ``argv`` (default: ``sys.argv[1:]``) and run search or retrain.

    Exits with status -1 if retrain mode lacks a valid ``--exported_arch_path``.
    """
    args = build_parser().parse_args(argv)

    if args.train_mode == 'retrain':
        # Explicit checks instead of `assert`, which `python -O` strips.
        if args.exported_arch_path is None:
            logger.error('When --train_mode is retrain, --exported_arch_path must be specified.')
            sys.exit(-1)
        if not os.path.isfile(args.exported_arch_path):
            logger.error("exported_arch_path {} should be a file.".format(args.exported_arch_path))
            sys.exit(-1)
        # Build the model under the architecture stored in the exported file
        # (presumably the checkpoint.json written by search mode).
        with fixed_arch(args.exported_arch_path):
            model = build_model(args)
    else:
        model = build_model(args)
    logger.info('SearchMobileNet model create done')
    model.init_model()
    logger.info('SearchMobileNet model init done')

    optimizer = build_optimizer(model, args.no_decay_keys)

    if args.train_mode == 'search':
        run_search(args, model, optimizer)
    elif args.train_mode == 'retrain':
        run_retrain(args, model, optimizer)


if __name__ == "__main__":
    main()