test.py 4.12 KB
Newer Older
Hang Zhang's avatar
Hang Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
###########################################################################
# Created by: Hang Zhang 
# Email: zhang.hang@rutgers.edu 
# Copyright (c) 2017
###########################################################################

import os
import numpy as np
from tqdm import tqdm

import torch
from torch.utils import data
import torchvision.transforms as transform
from torch.nn.parallel.scatter_gather import gather

import encoding.utils as utils
Hang Zhang's avatar
Hang Zhang committed
17
from encoding.nn import SegmentationLosses, SyncBatchNorm
Hang Zhang's avatar
Hang Zhang committed
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from encoding.parallel import DataParallelModel, DataParallelCriterion
from encoding.datasets import get_segmentation_dataset, test_batchify_fn
from encoding.models import get_model, get_segmentation_model, MultiEvalModule

from option import Options

def test(args):
    """Run multi-scale segmentation inference/evaluation on the chosen split.

    Mode is selected from ``args``:
      * ``args.eval``     -- 'val' split in 'testval' mode: accumulate and
                             report pixAcc / mIoU against ground truth.
      * ``args.test_val`` -- 'val' split in 'test' mode: predict and save masks.
      * otherwise         -- 'test' split in 'test' mode: predict and save masks.

    Predicted masks are written as PNG files into the ``outdir`` folder.

    Args:
        args: parsed command-line options (see ``option.Options``); must
            provide dataset/model/backbone settings, ``resume`` checkpoint
            path (when not using the model zoo), ``workers``, ``cuda`` and
            ``test_batch_size``.

    Raises:
        RuntimeError: when no model-zoo model is requested and ``args.resume``
            is missing or not a file.
    """
    # output folder; exist_ok avoids the check-then-create race
    outdir = 'outdir'
    os.makedirs(outdir, exist_ok=True)
    # data transforms: ImageNet mean/std normalization
    input_transform = transform.Compose([
        transform.ToTensor(),
        transform.Normalize([.485, .456, .406], [.229, .224, .225])])
    # dataset: pick split/mode according to the requested run type
    if args.eval:
        testset = get_segmentation_dataset(args.dataset, split='val', mode='testval',
                                           transform=input_transform)
    elif args.test_val:
        testset = get_segmentation_dataset(args.dataset, split='val', mode='test',
                                           transform=input_transform)
    else:
        testset = get_segmentation_dataset(args.dataset, split='test', mode='test',
                                           transform=input_transform)
    # dataloader; pinned memory / workers only make sense with CUDA
    loader_kwargs = {'num_workers': args.workers, 'pin_memory': True} \
        if args.cuda else {}
    test_data = data.DataLoader(testset, batch_size=args.test_batch_size,
                                drop_last=False, shuffle=False,
                                collate_fn=test_batchify_fn, **loader_kwargs)
    # model: either a pretrained model-zoo entry or a locally trained checkpoint
    if args.model_zoo is not None:
        model = get_model(args.model_zoo, pretrained=True)
        #model.base_size = args.base_size
        #model.crop_size = args.crop_size
    else:
        model = get_segmentation_model(args.model, dataset=args.dataset,
                                       backbone=args.backbone, aux=args.aux,
                                       se_loss=args.se_loss, norm_layer=SyncBatchNorm,
                                       base_size=args.base_size, crop_size=args.crop_size)
        # resuming checkpoint
        if args.resume is None or not os.path.isfile(args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
        checkpoint = torch.load(args.resume)
        # strict=False, so that it is compatible with old pytorch saved models
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))

    print(model)
    # multi-scale test pyramid; cityscapes uses a shifted set of scales
    scales = [0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25] if args.dataset == 'citys' else \
        [0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0]
    evaluator = MultiEvalModule(model, testset.num_class, scales=scales).cuda()
    evaluator.eval()
    metric = utils.SegmentationMetric(testset.num_class)

    tbar = tqdm(test_data)
    for i, (image, dst) in enumerate(tbar):
        if args.eval:
            # in 'testval' mode dst carries ground-truth masks: update metrics
            with torch.no_grad():
                predicts = evaluator.parallel_forward(image)
                metric.update(dst, predicts)
                pixAcc, mIoU = metric.get()
                tbar.set_description('pixAcc: %.4f, mIoU: %.4f' % (pixAcc, mIoU))
        else:
            # in 'test' mode dst carries image paths: save one palette mask each
            with torch.no_grad():
                outputs = evaluator.parallel_forward(image)
                predicts = [testset.make_pred(torch.max(output, 1)[1].cpu().numpy())
                            for output in outputs]
            for predict, impath in zip(predicts, dst):
                mask = utils.get_mask_pallete(predict, args.dataset)
                outname = os.path.splitext(impath)[0] + '.png'
                mask.save(os.path.join(outdir, outname))

if __name__ == "__main__":
    args = Options().parse()
    torch.manual_seed(args.seed)
    args.test_batch_size = torch.cuda.device_count()
    test(args)