"vscode:/vscode.git/clone" did not exist on "5c08a36cbfaeefab461ef7c42d897acae568b97a"
main.py 12.8 KB
Newer Older
Hang Zhang's avatar
Hang Zhang committed
1
2
3
4
5
6
7
8
9
10
11
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree 
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
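
# Example invocation (illustrative values only; the exact flags and their
# defaults are defined in the Options class below):
#   python main.py --dataset imagenet --model resnet50 --batch-size 128 \
#       --lr 0.1 --lr-scheduler cos --warmup-epochs 5 --epochs 120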

from __future__ import print_function
import os
import argparse
from tqdm import tqdm

import torch
import torch.nn as nn

import encoding
from encoding.nn import LabelSmoothing, NLLMultiLabelSmooth
from encoding.utils import (accuracy, AverageMeter, MixUpWrapper, LR_Scheduler)

class Options():
    def __init__(self):
        # data settings
        parser = argparse.ArgumentParser(description='Deep Encoding')
        parser.add_argument('--dataset', type=str, default='cifar10',
                            help='training dataset (default: cifar10)')
        parser.add_argument('--base-size', type=int, default=None,
                            help='base image size')
        parser.add_argument('--crop-size', type=int, default=224,
                            help='crop image size')
        parser.add_argument('--label-smoothing', type=float, default=0.0,
                            help='label-smoothing (default eta: 0.0)')
        parser.add_argument('--mixup', type=float, default=0.0,
                            help='mixup (default eta: 0.0)')
        parser.add_argument('--rand-aug', action='store_true',
                            default=False, help='apply random data augmentation')
        # model params 
        parser.add_argument('--model', type=str, default='densenet',
                            help='network model type (default: densenet)')
        parser.add_argument('--pretrained', action='store_true',
                            default=False, help='load pretrained model')
        parser.add_argument('--rectify', action='store_true', 
                            default=False, help='rectify convolution')
        parser.add_argument('--rectify-avg', action='store_true',
                            default=False, help='use average mode for rectified convolution')
        parser.add_argument('--last-gamma', action='store_true', default=False,
                            help='whether to init gamma of the last BN layer in \
                            each bottleneck to 0 (default: False)')
        parser.add_argument('--dropblock-prob', type=float, default=0,
                            help='DropBlock prob. default is 0.')
        parser.add_argument('--final-drop', type=float, default=0,
                            help='final dropout prob. default is 0.')
        # training hyper params
        parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                            help='batch size for training (default: 128)')
        parser.add_argument('--test-batch-size', type=int, default=256, metavar='N',
                            help='batch size for testing (default: 256)')
        parser.add_argument('--epochs', type=int, default=120, metavar='N',
                            help='number of epochs to train (default: 120)')
        parser.add_argument('--start_epoch', type=int, default=0,
                            metavar='N', help='the epoch number to start from (default: 0)')
        parser.add_argument('--workers', type=int, default=32,
                            metavar='N', help='number of dataloader workers')
        # optimizer
        parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
                            help='learning rate (default: 0.1)')
        parser.add_argument('--lr-scheduler', type=str, default='cos', 
                            help='learning rate scheduler (default: cos)')
        parser.add_argument('--warmup-epochs', type=int, default=0,
                            help='number of warmup epochs (default: 0)')
        parser.add_argument('--lr-step', type=int, default=40, metavar='N',
                            help='step size for the step LR scheduler (default: 40)')
        parser.add_argument('--momentum', type=float, default=0.9, 
                            metavar='M', help='SGD momentum (default: 0.9)')
        parser.add_argument('--weight-decay', type=float, default=1e-4,
                            metavar='M', help='SGD weight decay (default: 1e-4)')
        parser.add_argument('--no-bn-wd', action='store_true',
                            default=False, help='disable weight decay on BN parameters and biases')
        # cuda, seed and logging
        parser.add_argument('--no-cuda', action='store_true', 
                            default=False, help='disables CUDA training')
        parser.add_argument('--seed', type=int, default=1, metavar='S',
                            help='random seed (default: 1)')
        # checkpointing
        parser.add_argument('--resume', type=str, default=None,
                            help='put the path to resuming file if needed')
        parser.add_argument('--checkname', type=str, default='default',
                            help='set the checkpoint name')
        # evaluation option
        parser.add_argument('--eval', action='store_true', default=False,
                            help='evaluate on the validation set')
        parser.add_argument('--export', type=str, default=None,
                            help='path prefix for exporting the model weights')
        self.parser = parser

    def parse(self):
        args = self.parser.parse_args()
        return args

# global variables shared by main(), train(), and validate()
best_pred = 0.0
acclist_train = []
acclist_val = []

def main():
    # init the args
    global best_pred, acclist_train, acclist_val
    args = Options().parse()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    print(args)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
    # init dataloader
    transform_train, transform_val = encoding.transforms.get_transform(
            args.dataset, args.base_size, args.crop_size, args.rand_aug)
    trainset = encoding.datasets.get_dataset(args.dataset, root=os.path.expanduser('~/.encoding/data'),
                                             transform=transform_train, train=True, download=True)
    valset = encoding.datasets.get_dataset(args.dataset, root=os.path.expanduser('~/.encoding/data'),
                                           transform=transform_val, train=False, download=True)
    train_loader = torch.utils.data.DataLoader(
        trainset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, drop_last=True, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        valset, batch_size=args.test_batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)
    
    # init the model
    model_kwargs = {}
    if args.pretrained:
        model_kwargs['pretrained'] = True

    if args.final_drop > 0.0:
        model_kwargs['final_drop'] = args.final_drop

    if args.dropblock_prob > 0.0:
        model_kwargs['dropblock_prob'] = args.dropblock_prob

    if args.last_gamma:
        model_kwargs['last_gamma'] = True

    if args.rectify:
        model_kwargs['rectified_conv'] = True
        model_kwargs['rectify_avg'] = args.rectify_avg

    model = encoding.models.get_model(args.model, **model_kwargs)
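    # When DropBlock is enabled, attach a schedule to the model's DropBlock
    # modules: judging by its arguments, reset_dropblock ramps the drop
    # probability from 0.0 up to args.dropblock_prob over the iterations that
    # remain after warmup.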
    if args.dropblock_prob > 0.0:
        from functools import partial
        from encoding.nn import reset_dropblock
        nr_iters = (args.epochs - 2 * args.warmup_epochs) * len(train_loader)
        apply_drop_prob = partial(reset_dropblock, args.warmup_epochs*len(train_loader),
                                  nr_iters, 0.0, args.dropblock_prob)
        model.apply(apply_drop_prob)

    print(model)
    # criterion and optimizer
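    # Choose the training loss: MixUpWrapper blends pairs of training samples
    # (the hard-coded 1000 is the class count, presumably for ImageNet), so it
    # is paired with the soft-target NLLMultiLabelSmooth loss; otherwise use
    # label smoothing or plain cross-entropy.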
    if args.mixup > 0:
        train_loader = MixUpWrapper(args.mixup, 1000, train_loader,
                                    list(range(torch.cuda.device_count())))
        criterion = NLLMultiLabelSmooth(args.label_smoothing)
    elif args.label_smoothing > 0.0:
        criterion = LabelSmoothing(args.label_smoothing)
    else:
        criterion = nn.CrossEntropyLoss()

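    # Optionally exclude BatchNorm parameters and biases from weight decay, a
    # common trick for training deep networks; parameters are matched simply by
    # the substrings 'bn' and 'bias' in their names.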
    if args.no_bn_wd:
        param_dict = dict(model.named_parameters())
        bn_params = [v for n, v in param_dict.items() if ('bn' in n or 'bias' in n)]
        rest_params = [v for n, v in param_dict.items() if not ('bn' in n or 'bias' in n)]
        print(" Weight decay NOT applied to BN parameters ")
        print(f'len(parameters): {len(list(model.parameters()))} = {len(bn_params)} + {len(rest_params)}')
        optimizer = torch.optim.SGD([{'params': bn_params, 'weight_decay': 0},
                                     {'params': rest_params, 'weight_decay': args.weight_decay}],
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    if args.cuda:
        model.cuda()
        criterion.cuda()
        # Please use CUDA_VISIBLE_DEVICES to control the number of gpus
        model = nn.DataParallel(model)
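    # Optionally restore training state (model weights, optimizer state, best
    # accuracy, and per-epoch accuracy history) from a saved checkpoint.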
    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch'] + 1 if args.start_epoch == 0 else args.start_epoch
            best_pred = checkpoint['best_pred']
            acclist_train = checkpoint['acclist_train']
            acclist_val = checkpoint['acclist_val']
            model.module.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                .format(args.resume, checkpoint['epoch']))
        else:
            raise RuntimeError("=> no resume checkpoint found at '{}'".format(args.resume))
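    # The LR scheduler is invoked once per batch inside train(), so warmup and
    # the cosine/step schedules are applied at iteration granularity.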
    scheduler = LR_Scheduler(args.lr_scheduler,
                             base_lr=args.lr,
                             num_epochs=args.epochs,
                             iters_per_epoch=len(train_loader),
                             warmup_epochs=args.warmup_epochs,
                             lr_step=args.lr_step)
    def train(epoch):
        model.train()
        losses = AverageMeter()
        top1 = AverageMeter()
        global best_pred, acclist_train
        tbar = tqdm(train_loader, desc='\r')
        for batch_idx, (data, target) in enumerate(tbar):
            scheduler(optimizer, batch_idx, epoch, best_pred)
            #criterion.update(batch_idx, epoch)
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            acc1 = accuracy(output, target, topk=(1,))
            top1.update(acc1[0], data.size(0))
            losses.update(loss.item(), data.size(0))
            tbar.set_description('\rLoss: %.3f | Top1: %.3f'%(losses.avg, top1.avg))

        acclist_train += [top1.avg]

    def validate(epoch):
        model.eval()
        top1 = AverageMeter()
        top5 = AverageMeter()
        global best_pred, acclist_train, acclist_val
        is_best = False
        tbar = tqdm(val_loader, desc='\r')
        for batch_idx, (data, target) in enumerate(tbar):
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            with torch.no_grad():
                output = model(data)
                acc1, acc5 = accuracy(output, target, topk=(1, 5))
                top1.update(acc1[0], data.size(0))
                top5.update(acc5[0], data.size(0))

            tbar.set_description('Top1: %.3f | Top5: %.3f'%(top1.avg, top5.avg))

        if args.eval:
            print('Top1 Acc: %.3f | Top5 Acc: %.3f '%(top1.avg, top5.avg))
            return
        # save checkpoint
        acclist_val += [top1.avg]
        if top1.avg > best_pred:
            best_pred = top1.avg 
            is_best = True
        encoding.utils.save_checkpoint({
            'args': args,
            'epoch': epoch,
            'state_dict': model.module.state_dict(),
            'optimizer': optimizer.state_dict(),
            'best_pred': best_pred,
            'acclist_train': acclist_train,
            'acclist_val': acclist_val,
            }, args=args, is_best=is_best)

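    # --export saves just the unwrapped model weights and exits without
    # training; --eval runs a single validation pass instead of training.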
    if args.export:
        torch.save(model.module.state_dict(), args.export + '.pth')
        return

    if args.eval:
        validate(args.start_epoch)
        return

    for epoch in range(args.start_epoch, args.epochs):
        train(epoch)
        validate(epoch)

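    # run one final validation pass after the last training epoch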
    validate(epoch)

if __name__ == "__main__":
    main()