###########################################################################
# Created by: Hang Zhang 
# Email: zhang.hang@rutgers.edu 
# Copyright (c) 2017
###########################################################################
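"""Test/evaluation script for semantic segmentation models in PyTorch-Encoding.

Example invocation (a sketch; the checkpoint path is illustrative):

    python test.py --dataset ade20k --model encnet --backbone resnet50 --eval --resume runs/encnet_checkpoint.pth.tar
"""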

import os
import argparse
from tqdm import tqdm

import torch
from torch.utils import data
import torchvision.transforms as transform

import encoding.utils as utils
from encoding.nn import SyncBatchNorm
from encoding.datasets import get_dataset, test_batchify_fn
from encoding.models import get_model, get_segmentation_model, MultiEvalModule


class Options():
    def __init__(self):
        parser = argparse.ArgumentParser(description='PyTorch Segmentation')
        # model and dataset 
        parser.add_argument('--model', type=str, default='encnet',
                            help='model name (default: encnet)')
        parser.add_argument('--backbone', type=str, default='resnet50',
                            help='backbone name (default: resnet50)')
        parser.add_argument('--dataset', type=str, default='ade20k',
                            help='dataset name (default: ade20k)')
        parser.add_argument('--workers', type=int, default=16,
                            metavar='N', help='dataloader threads')
        parser.add_argument('--base-size', type=int, default=520,
                            help='base image size')
        parser.add_argument('--crop-size', type=int, default=480,
                            help='crop image size')
        parser.add_argument('--train-split', type=str, default='train',
                            help='dataset train split (default: train)')
        # training hyper params
        parser.add_argument('--aux', action='store_true', default=False,
                            help='use auxiliary loss')
        parser.add_argument('--se-loss', action='store_true', default=False,
                            help='use Semantic Encoding loss (SE-loss)')
        parser.add_argument('--se-weight', type=float, default=0.2,
                            help='SE-loss weight (default: 0.2)')
        parser.add_argument('--batch-size', type=int, default=16,
                            metavar='N', help='input batch size for '
                            'training (default: 16)')
        parser.add_argument('--test-batch-size', type=int, default=16,
                            metavar='N', help='input batch size for '
                            'testing (default: 16; overridden to the '
                            'number of GPUs in __main__)')
        # cuda, seed and logging
        parser.add_argument('--no-cuda', action='store_true', default=False,
                            help='disables CUDA training')
        parser.add_argument('--seed', type=int, default=1, metavar='S',
                            help='random seed (default: 1)')
        # checking point
        parser.add_argument('--resume', type=str, default=None,
                            help='path to a checkpoint to resume from')
        parser.add_argument('--verify', type=str, default=None,
                            help='path to a state-dict file to load and verify')
        parser.add_argument('--model-zoo', type=str, default=None,
                            help='name of a pretrained model-zoo model to evaluate')
        # evaluation option
        parser.add_argument('--eval', action='store_true', default=False,
                            help='evaluate mIoU on the validation set')
        parser.add_argument('--export', type=str, default=None,
                            help='path prefix for exporting the model state dict')
        parser.add_argument('--acc-bn', action='store_true', default=False,
                            help='re-accumulate BN statistics')
        parser.add_argument('--test-val', action='store_true', default=False,
                            help='generate masks on val set')
        parser.add_argument('--no-val', action='store_true', default=False,
                            help='skip validation during training')
        # test option
        parser.add_argument('--test-folder', type=str, default=None,
                            help='path to test image folder')
        # the parser
        self.parser = parser

    def parse(self):
        args = self.parser.parse_args()
        args.cuda = not args.no_cuda and torch.cuda.is_available()
        print(args)
        return args

def test(args):
    # output folder
    outdir = 'outdir'
    os.makedirs(outdir, exist_ok=True)
    # data transforms
    input_transform = transform.Compose([
        transform.ToTensor(),
        # normalize with ImageNet channel means and standard deviations
        transform.Normalize([.485, .456, .406], [.229, .224, .225])])
    # dataset
    if args.eval:
        testset = get_dataset(args.dataset, split='val', mode='testval',
                              transform=input_transform)
    elif args.test_val:
        testset = get_dataset(args.dataset, split='val', mode='test',
                              transform=input_transform)
    else:
        testset = get_dataset(args.dataset, split='test', mode='test',
                              transform=input_transform)
    # dataloader
    loader_kwargs = {'num_workers': args.workers, 'pin_memory': True} \
        if args.cuda else {}
    test_data = data.DataLoader(testset, batch_size=args.test_batch_size,
                                drop_last=False, shuffle=False,
                                collate_fn=test_batchify_fn, **loader_kwargs)
    # model
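    # fall back to pretrained model-zoo weights only when no checkpoint is given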
    pretrained = args.resume is None and args.verify is None
    if args.model_zoo is not None:
        model = get_model(args.model_zoo, pretrained=pretrained)
        model.base_size = args.base_size
        model.crop_size = args.crop_size
    else:
        model = get_segmentation_model(args.model, dataset=args.dataset,
                                       backbone=args.backbone, aux=args.aux,
                                       se_loss=args.se_loss,
                                       norm_layer=torch.nn.BatchNorm2d if args.acc_bn else SyncBatchNorm,
                                       base_size=args.base_size, crop_size=args.crop_size)

    # resuming checkpoint
    if args.verify is not None and os.path.isfile(args.verify):
        print("=> loading checkpoint '{}'".format(args.verify))
        model.load_state_dict(torch.load(args.verify))
    elif args.resume is not None and os.path.isfile(args.resume):
        checkpoint = torch.load(args.resume)
        # strict=False keeps compatibility with models saved by older versions
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
    elif not pretrained:
        raise RuntimeError("=> no checkpoint found")

    print(model)
    if args.acc_bn:
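        # re-estimate BN running statistics on the training set ("precise BN")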
        from encoding.utils.precise_bn import update_bn_stats
        data_kwargs = {'transform': input_transform, 'base_size': args.base_size,
                       'crop_size': args.crop_size}
        trainset = get_dataset(args.dataset, split=args.train_split, mode='train', **data_kwargs)
        trainloader = data.DataLoader(ReturnFirstClosure(trainset), batch_size=args.batch_size,
                                      drop_last=True, shuffle=True, **loader_kwargs)
        print('Resetting BN statistics')
        model.cuda()
        update_bn_stats(model, trainloader)

    if args.export:
        torch.save(model.state_dict(), args.export + '.pth')
        return

    scales = [0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25] if args.dataset == 'citys' else \
        [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
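    # MultiEvalModule runs the model at each scale above and aggregates
    # the predictions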
    evaluator = MultiEvalModule(model, testset.num_class, scales=scales).cuda()
    evaluator.eval()
    metric = utils.SegmentationMetric(testset.num_class)

    tbar = tqdm(test_data)
    for image, dst in tbar:
        if args.eval:
            with torch.no_grad():
                predicts = evaluator.parallel_forward(image)
                metric.update(dst, predicts)
                pixAcc, mIoU = metric.get()
                tbar.set_description('pixAcc: %.4f, mIoU: %.4f' % (pixAcc, mIoU))
        else:
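            # predict, colorize, and save a palette mask for each image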
            with torch.no_grad():
                outputs = evaluator.parallel_forward(image)
                predicts = [testset.make_pred(torch.max(output, 1)[1].cpu().numpy())
                            for output in outputs]
            for predict, impath in zip(predicts, dst):
                mask = utils.get_mask_pallete(predict, args.dataset)
                outname = os.path.splitext(impath)[0] + '.png'
                mask.save(os.path.join(outdir, outname))

    if args.eval:
        print('pixAcc: %.4f, mIoU: %.4f' % (pixAcc, mIoU))

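# Dataset wrapper that returns only the image (the first element of each
# sample); used when re-accumulating BN statistics, where labels are unused.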
class ReturnFirstClosure(object):
    def __init__(self, dataset):
        self._data = dataset

    def __len__(self):
        return len(self._data)

    def __getitem__(self, idx):
        outputs = self._data[idx]
        return outputs[0]

if __name__ == "__main__":
    args = Options().parse()
    torch.manual_seed(args.seed)
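    # one sample per GPU: MultiEvalModule.parallel_forward scatters the
    # batch across the visible devices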
    args.test_batch_size = torch.cuda.device_count()
    test(args)