main.py 6.38 KB
Newer Older
Hang Zhang's avatar
Hang Zhang committed
1
2
3
4
5
6
7
8
9
10
11
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree.
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

from __future__ import print_function
Hang Zhang's avatar
Hang Zhang committed
12
import os
Hang Zhang's avatar
Hang Zhang committed
13
from tqdm import tqdm
Hang Zhang's avatar
v1.0.1  
Hang Zhang committed
14

Hang Zhang's avatar
Hang Zhang committed
15
16
17
import torch
import torch.nn as nn

Hang Zhang's avatar
Hang Zhang committed
18
import encoding
Hang Zhang's avatar
Hang Zhang committed
19
from option import Options
Hang Zhang's avatar
Hang Zhang committed
20

Hang Zhang's avatar
Hang Zhang committed
21
# Global training state. main() and its nested train()/validate() closures
# rebind these names through ``global`` declarations, and they are written to
# and restored from the checkpoint dictionary.
#   best_pred     - best validation top-1 accuracy seen so far
#   acclist_train - per-epoch mean training top-1 accuracy
#   acclist_val   - per-epoch mean validation top-1 accuracy
best_pred, acclist_train, acclist_val = 0.0, [], []
Hang Zhang's avatar
v1.0.1  
Hang Zhang committed
25

Hang Zhang's avatar
Hang Zhang committed
26
def main():
    """Train (or only evaluate) a classifier configured by command-line options.

    Parses options with ``Options().parse()``, builds train/val loaders for
    ``args.dataset`` (rooted at ``~/.encoding/data``, with ``download=True``),
    builds ``args.model``, optionally resumes from ``args.resume``, then either
    runs one validation pass (``args.eval``) or trains for epochs
    ``args.start_epoch .. args.epochs`` inclusive. It validates and
    checkpoints after each epoch.

    Raises:
        RuntimeError: if ``args.resume`` is set but is not an existing file.
    """
    # init the args
    global best_pred, acclist_train, acclist_val
    args = Options().parse()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    print(args)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    # init dataloader
    transform_train, transform_val = encoding.transforms.get_transform(args.dataset)
    data_root = os.path.expanduser('~/.encoding/data')
    trainset = encoding.datasets.get_dataset(args.dataset, root=data_root,
                                             transform=transform_train, train=True, download=True)
    valset = encoding.datasets.get_dataset(args.dataset, root=data_root,
                                           transform=transform_val, train=False, download=True)
    train_loader = torch.utils.data.DataLoader(
        trainset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        valset, batch_size=args.test_batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # init the model
    model = encoding.models.get_model(args.model, pretrained=args.pretrained)
    print(model)
    # criterion and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    if args.cuda:
        model.cuda()
        criterion.cuda()
        # Please use CUDA_VISIBLE_DEVICES to control the number of gpus
        model = nn.DataParallel(model)

    # The model is wrapped in DataParallel only on the CUDA path, so
    # ``model.module`` does not exist on CPU. Always load/save the bare network
    # so checkpoints work (and are interchangeable) in both setups.
    net = model.module if isinstance(model, nn.DataParallel) else model

    # check point
    if args.resume is not None:
        if not os.path.isfile(args.resume):
            raise RuntimeError("=> no resume checkpoint found at '{}'".format(args.resume))
        print("=> loading checkpoint '{}'".format(args.resume))
        # Map to CPU so a checkpoint written on a GPU host can be resumed
        # anywhere; load_state_dict copies values onto the parameters' device.
        checkpoint = torch.load(args.resume, map_location='cpu')
        args.start_epoch = checkpoint['epoch'] + 1
        best_pred = checkpoint['best_pred']
        acclist_train = checkpoint['acclist_train']
        acclist_val = checkpoint['acclist_val']
        net.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(args.resume, checkpoint['epoch']))

    scheduler = encoding.utils.LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                            len(train_loader), args.lr_step)

    def train(epoch):
        """Run one training epoch; append its mean top-1 accuracy to acclist_train."""
        global best_pred, acclist_train
        model.train()
        losses = AverageMeter()
        top1 = AverageMeter()
        tbar = tqdm(train_loader, desc='\r')
        for batch_idx, (data, target) in enumerate(tbar):
            # The LR scheduler is invoked every iteration (presumably to set
            # this step's learning rate on the optimizer).
            scheduler(optimizer, batch_idx, epoch, best_pred)
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            acc1 = accuracy(output, target, topk=(1,))
            # .item() keeps meters (and the saved histories) as plain floats,
            # matching how the loss is tracked, instead of device tensors.
            top1.update(acc1[0].item(), data.size(0))
            losses.update(loss.item(), data.size(0))
            tbar.set_description('\rLoss: %.3f | Top1: %.3f'%(losses.avg, top1.avg))

        acclist_train += [top1.avg]

    def validate(epoch):
        """Evaluate on the validation set.

        With ``args.eval`` only the accuracies are printed. Otherwise the epoch's
        top-1 is appended to acclist_val, best_pred is updated, and a checkpoint
        is saved (flagged ``is_best`` on improvement).
        """
        global best_pred, acclist_train, acclist_val
        model.eval()
        top1 = AverageMeter()
        top5 = AverageMeter()
        is_best = False
        tbar = tqdm(val_loader, desc='\r')
        for data, target in tbar:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            with torch.no_grad():
                output = model(data)
                acc1, acc5 = accuracy(output, target, topk=(1, 5))
                top1.update(acc1[0].item(), data.size(0))
                top5.update(acc5[0].item(), data.size(0))

            tbar.set_description('Top1: %.3f | Top5: %.3f'%(top1.avg, top5.avg))

        if args.eval:
            print('Top1 Acc: %.3f | Top5 Acc: %.3f '%(top1.avg, top5.avg))
            return
        # save checkpoint
        acclist_val += [top1.avg]
        if top1.avg > best_pred:
            best_pred = top1.avg
            is_best = True
        encoding.utils.save_checkpoint({
            'epoch': epoch,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'best_pred': best_pred,
            'acclist_train': acclist_train,
            'acclist_val': acclist_val,
            }, args=args, is_best=is_best)

    if args.eval:
        validate(args.start_epoch)
        return

    for epoch in range(args.start_epoch, args.epochs + 1):
        train(epoch)
        validate(epoch)

def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: score tensor whose dim 1 indexes classes (``topk`` is taken
            along dim 1); ``max(topk)`` must not exceed that dimension.
        target: tensor holding one ground-truth class index per sample.
        topk: iterable of k values to evaluate.

    Returns:
        list with one single-element tensor per k, in ``topk`` order, holding
        the percentage of samples whose target is among the top-k scores.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # Indices of the maxk largest scores (sorted), transposed so row i
        # holds every sample's i-th ranked prediction.
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # ``correct`` inherits the transposed, non-contiguous layout of
            # ``pred``; ``view(-1)`` can raise there for k > 1, while
            # ``reshape`` copies only when needed.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


class AverageMeter(object):
    """Keep the latest value plus a sample-weighted running mean.

    Each ``update(val, n)`` treats ``val`` as the value of ``n`` samples; after
    it, ``avg`` equals ``sum / count`` over everything recorded since the last
    ``reset()``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Discard all recorded values and zero every statistic."""
        self.val = self.avg = 0
        self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh ``avg``."""
        self.val = val
        # Accumulate the weighted total before the count, then recompute the mean.
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
Hang Zhang's avatar
Hang Zhang committed
179
180

# Run training/evaluation only when this file is executed as a script.
if __name__ == '__main__':
    main()