test.py 3.87 KB
Newer Older
Hang Zhang's avatar
Hang Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
###########################################################################
# Created by: Hang Zhang 
# Email: zhang.hang@rutgers.edu 
# Copyright (c) 2017
###########################################################################

import os
import numpy as np
from tqdm import tqdm

import torch
from torch.utils import data
import torchvision.transforms as transform
from torch.nn.parallel.scatter_gather import gather

import encoding.utils as utils
from encoding.nn import SegmentationLosses, BatchNorm2d
from encoding.parallel import DataParallelModel, DataParallelCriterion
from encoding.datasets import get_segmentation_dataset, test_batchify_fn
from encoding.models import get_model, get_segmentation_model, MultiEvalModule

from option import Options

def test(args):
    """Run a segmentation model over a dataset split, then evaluate or save masks.

    When ``args.eval`` is true, the ``'val'`` split is loaded in ``'testval'``
    mode. The running pixel accuracy / mIoU from ``utils.SegmentationMetric``
    is shown on the progress bar. Otherwise the ``'test'`` split is loaded in
    ``'test'`` mode. Each prediction is converted to a colour-palette mask and
    saved as ``<image name>.png`` inside the ``outdir`` folder.

    Args:
        args: parsed command-line options (from ``option.Options``). Fields
            read here: ``eval``, ``dataset``, ``workers``, ``cuda``,
            ``test_batch_size``, ``model_zoo``, ``model``, ``backbone``,
            ``aux``, ``se_loss``, ``base_size``, ``crop_size``, ``resume``.

    Raises:
        RuntimeError: if ``args.model_zoo`` is None and ``args.resume`` is
            None or does not name an existing file.
    """
    # output folder for predicted masks; exist_ok avoids a check-then-create race
    outdir = 'outdir'
    os.makedirs(outdir, exist_ok=True)

    # data transforms: per-channel mean/std normalization (the commonly used
    # ImageNet statistics)
    input_transform = transform.Compose([
        transform.ToTensor(),
        transform.Normalize([.485, .456, .406], [.229, .224, .225])])

    # dataset: 'val' split when evaluating, 'test' split when producing masks
    if args.eval:
        testset = get_segmentation_dataset(args.dataset, split='val', mode='testval',
                                           transform=input_transform)
    else:
        testset = get_segmentation_dataset(args.dataset, split='test', mode='test',
                                           transform=input_transform)

    # dataloader: worker processes and pinned memory only when running on CUDA
    loader_kwargs = {'num_workers': args.workers, 'pin_memory': True} \
        if args.cuda else {}
    test_data = data.DataLoader(testset, batch_size=args.test_batch_size,
                                drop_last=False, shuffle=False,
                                collate_fn=test_batchify_fn, **loader_kwargs)

    # model: either a pretrained model-zoo entry, or a freshly built model
    # whose weights come from the --resume checkpoint
    if args.model_zoo is not None:
        model = get_model(args.model_zoo, pretrained=True)
    else:
        model = get_segmentation_model(args.model, dataset=args.dataset,
                                       backbone=args.backbone, aux=args.aux,
                                       se_loss=args.se_loss, norm_layer=BatchNorm2d,
                                       base_size=args.base_size, crop_size=args.crop_size)
        # resuming checkpoint
        if args.resume is None or not os.path.isfile(args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
        checkpoint = torch.load(args.resume)
        # Strict loading (PyTorch's default): a checkpoint whose keys do not
        # match the model raises here. Otherwise evaluation could silently
        # run with partially-initialised weights.
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))

    print(model)

    # multi-scale inference; the 'citys' dataset additionally uses scales 2.0 and 2.25
    scales = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25] if args.dataset == 'citys' else \
        [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
    evaluator = MultiEvalModule(model, testset.num_class, scales=scales).cuda()
    evaluator.eval()
    metric = utils.SegmentationMetric(testset.num_class)

    tbar = tqdm(test_data)
    for image, dst in tbar:
        if args.eval:
            # in 'testval' mode dst is handed to the metric as the target
            with torch.no_grad():
                predicts = evaluator.parallel_forward(image)
                metric.update(dst, predicts)
                pixAcc, mIoU = metric.get()
                tbar.set_description( 'pixAcc: %.4f, mIoU: %.4f' % (pixAcc, mIoU))
        else:
            # in 'test' mode dst is treated as the batch's image paths
            with torch.no_grad():
                outputs = evaluator.parallel_forward(image)
                # argmax over dim 1 gives the per-pixel class index
                predicts = [testset.make_pred(torch.max(output, 1)[1].cpu().numpy())
                            for output in outputs]
            for predict, impath in zip(predicts, dst):
                mask = utils.get_mask_pallete(predict, args.dataset)
                outname = os.path.splitext(impath)[0] + '.png'
                mask.save(os.path.join(outdir, outname))

if __name__ == "__main__":
    # Script entry point: parse the CLI options and seed torch's RNG. The test
    # batch size is set to the number of visible CUDA devices (presumably one
    # image per GPU for parallel_forward; confirm against MultiEvalModule).
    args = Options().parse()
    torch.manual_seed(args.seed)
    num_visible_gpus = torch.cuda.device_count()
    args.test_batch_size = num_visible_gpus
    test(args)