###########################################################################
# Created by: Hang Zhang 
# Email: zhang.hang@rutgers.edu 
# Copyright (c) 2017
###########################################################################
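"""Test/evaluation script for semantic segmentation models.

Supports mIoU evaluation on the val set (--eval), prediction-mask
generation on the val/test splits (--test-val / default), BatchNorm
statistics re-accumulation (--acc-bn), and exporting a model state
dict (--export).

Example usage (checkpoint path is illustrative):
    python test.py --dataset ade20k --model encnet --backbone resnet50 \
        --resume runs/encnet_checkpoint.pth.tar --eval
"""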

import os
import argparse
from tqdm import tqdm

import torch
from torch.utils import data
import torchvision.transforms as transform

import encoding.utils as utils
from encoding.nn import SyncBatchNorm
from encoding.datasets import get_dataset, test_batchify_fn
from encoding.models import get_model, get_segmentation_model, MultiEvalModule


class Options:
    def __init__(self):
        parser = argparse.ArgumentParser(description='PyTorch Segmentation')
        # model and dataset 
        parser.add_argument('--model', type=str, default='encnet',
                            help='model name (default: encnet)')
        parser.add_argument('--backbone', type=str, default='resnet50',
                            help='backbone name (default: resnet50)')
        parser.add_argument('--dataset', type=str, default='ade20k',
                            help='dataset name (default: ade20k)')
        parser.add_argument('--workers', type=int, default=16,
                            metavar='N', help='dataloader threads')
        parser.add_argument('--base-size', type=int, default=520,
                            help='base image size')
        parser.add_argument('--crop-size', type=int, default=480,
                            help='crop image size')
        parser.add_argument('--train-split', type=str, default='train',
                            help='dataset train split (default: train)')
        # training hyper params
        parser.add_argument('--aux', action='store_true', default=False,
                            help='use auxiliary loss')
        parser.add_argument('--se-loss', action='store_true', default=False,
                            help='use semantic encoding loss (SE-loss)')
        parser.add_argument('--se-weight', type=float, default=0.2,
                            help='SE-loss weight (default: 0.2)')
        parser.add_argument('--batch-size', type=int, default=16,
                            metavar='N', help='input batch size for \
                            training (default: 16)')
        parser.add_argument('--test-batch-size', type=int, default=16,
                            metavar='N', help='input batch size for \
                            testing (default: 16)')
        # cuda, seed and logging
        parser.add_argument('--no-cuda', action='store_true', default=False,
                            help='disables CUDA training')
        parser.add_argument('--seed', type=int, default=1, metavar='S',
                            help='random seed (default: 1)')
        # checkpoint
        parser.add_argument('--resume', type=str, default=None,
                            help='path to a checkpoint to resume from')
        parser.add_argument('--verify', type=str, default=None,
                            help='path to a checkpoint to verify')
        parser.add_argument('--model-zoo', type=str, default=None,
                            help='evaluating on model zoo model')
        # evaluation option
        parser.add_argument('--eval', action='store_true', default=False,
                            help='evaluate mIoU on the val set')
        parser.add_argument('--export', type=str, default=None,
                            help='path prefix for exporting the model state dict')
        parser.add_argument('--acc-bn', action='store_true', default=False,
                            help='re-accumulate BN statistics')
        parser.add_argument('--test-val', action='store_true', default=False,
                            help='generate masks on val set')
        parser.add_argument('--no-val', action='store_true', default=False,
                            help='skip validation during training')
        # test option
        parser.add_argument('--test-folder', type=str, default=None,
                            help='path to test image folder')
        # the parser
        self.parser = parser

    def parse(self):
        args = self.parser.parse_args()
        args.cuda = not args.no_cuda and torch.cuda.is_available()
        print(args)
        return args

@torch.no_grad()
def reset_bn_statistics(m):
    if isinstance(m, torch.nn.BatchNorm2d):
        # momentum=None switches BatchNorm to a cumulative moving average,
        # so the reset statistics are re-estimated over all batches seen
        # (momentum=0.0 would freeze them at their reset values)
        m.momentum = None
        m.reset_running_stats()

def test(args):
    # output folder
    outdir = 'outdir'
    os.makedirs(outdir, exist_ok=True)
    # data transforms (normalize with ImageNet mean/std)
    input_transform = transform.Compose([
        transform.ToTensor(),
        transform.Normalize([.485, .456, .406], [.229, .224, .225])])
    # dataset
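    # mode='testval' keeps ground-truth masks for scoring, while
    # mode='test' yields image paths instead of labels so that
    # prediction masks can be written to disk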
    if args.eval:
        testset = get_dataset(args.dataset, split='val', mode='testval',
                              transform=input_transform)
    elif args.test_val:
        testset = get_dataset(args.dataset, split='val', mode='test',
                              transform=input_transform)
    else:
        testset = get_dataset(args.dataset, split='test', mode='test',
                              transform=input_transform)
    # dataloader
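    # test_batchify_fn collates variable-size test images into lists
    # rather than stacked tensors, so each image keeps its original size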
    loader_kwargs = {'num_workers': args.workers, 'pin_memory': True} \
        if args.cuda else {}
    test_data = data.DataLoader(testset, batch_size=args.test_batch_size,
                                drop_last=False, shuffle=False,
                                collate_fn=test_batchify_fn, **loader_kwargs)
    # model
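    # either load a pretrained model from the model zoo, or build the
    # architecture and restore weights from a local checkpoint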
    if args.model_zoo is not None:
        model = get_model(args.model_zoo, pretrained=True)
        #model.base_size = args.base_size
        #model.crop_size = args.crop_size
    else:
        model = get_segmentation_model(args.model, dataset=args.dataset,
                                       backbone=args.backbone, aux=args.aux,
                                       se_loss=args.se_loss,
                                       norm_layer=torch.nn.BatchNorm2d if args.acc_bn else SyncBatchNorm,
                                       base_size=args.base_size, crop_size=args.crop_size)
        # resuming checkpoint
        if args.verify is not None and os.path.isfile(args.verify):
            print("=> loading checkpoint '{}'".format(args.verify))
            model.load_state_dict(torch.load(args.verify))
        elif args.resume is not None and os.path.isfile(args.resume):
            checkpoint = torch.load(args.resume)
            # strict=False keeps compatibility with models saved by older PyTorch
            model.load_state_dict(checkpoint['state_dict'], strict=False)
            print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            raise RuntimeError("=> no checkpoint found")

    print(model)
    # accumulate bn statistics
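    # (--acc-bn) re-estimate BatchNorm running statistics by forwarding
    # training batches with gradients disabled before evaluation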
    if args.acc_bn:
        print('Resetting BN statistics')
        model.apply(reset_bn_statistics)
        data_kwargs = {'transform': input_transform, 'base_size': args.base_size,
                       'crop_size': args.crop_size}
        trainset = get_dataset(args.dataset, split=args.train_split, mode='train', **data_kwargs)
        trainloader = data.DataLoader(trainset, batch_size=args.batch_size,
                                      drop_last=True, shuffle=True, **loader_kwargs)
        tbar = tqdm(trainloader)
        model.train()
        model.cuda()
        for i, (image, dst) in enumerate(tbar):
            image = image.cuda()
            with torch.no_grad():
                model(image)
            if i > 1000:
                break

    if args.export:
        torch.save(model.state_dict(), args.export + '.pth')
        return

    # multi-scale evaluation; Cityscapes uses a larger sweep of scales
    scales = [0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25] if args.dataset == 'citys' else \
            [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
    evaluator = MultiEvalModule(model, testset.num_class, scales=scales).cuda()
    evaluator.eval()
    metric = utils.SegmentationMetric(testset.num_class)

    tbar = tqdm(test_data)
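    # each batch is a list of images (one per GPU); parallel_forward
    # scatters them across devices and returns per-image predictions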
    for i, (image, dst) in enumerate(tbar):
        if args.eval:
            with torch.no_grad():
                predicts = evaluator.parallel_forward(image)
                metric.update(dst, predicts)
                pixAcc, mIoU = metric.get()
                tbar.set_description('pixAcc: %.4f, mIoU: %.4f' % (pixAcc, mIoU))
        else:
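            # no labels available: write color-coded prediction masks to disk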
            with torch.no_grad():
                outputs = evaluator.parallel_forward(image)
                predicts = [testset.make_pred(torch.max(output, 1)[1].cpu().numpy())
                            for output in outputs]
            for predict, impath in zip(predicts, dst):
                mask = utils.get_mask_pallete(predict, args.dataset)
                outname = os.path.splitext(impath)[0] + '.png'
                mask.save(os.path.join(outdir, outname))

    if args.eval:
        print('pixAcc: %.4f, mIoU: %.4f' % (pixAcc, mIoU))

if __name__ == "__main__":
    args = Options().parse()
    torch.manual_seed(args.seed)
    # MultiEvalModule scatters one image per GPU, so the effective test
    # batch size equals the number of visible devices
    args.test_batch_size = torch.cuda.device_count()
    test(args)