###########################################################################
# Created by: Hang Zhang
# Email: zhang.hang@rutgers.edu
# Copyright (c) 2017
###########################################################################

from PIL import Image, ImageOps, ImageFilter
import os
import math
import random
import numpy as np
from tqdm import trange

import torch
from .base import BaseDataset

class ContextSegmentation(BaseDataset):
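    """PASCAL Context semantic segmentation dataset.

    Reads PASCAL VOC 2010 images together with the whole-scene annotations
    of the Detail API (trainval_merged.json) and exposes the standard
    59-class subset; raw Detail label ids are remapped to contiguous
    train ids, with background shifted to -1 by _mask_transform.
    """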
    BASE_DIR = 'VOCdevkit/VOC2010'
    NUM_CLASS = 59
    def __init__(self, root=os.path.expanduser('~/.encoding/data'), split='train',
                 mode=None, transform=None, target_transform=None):
        super(ContextSegmentation, self).__init__(
            root, split, mode, transform, target_transform)
        from detail import Detail
        root = os.path.join(root, self.BASE_DIR)
        annFile = os.path.join(root, 'trainval_merged.json')
        imgDir = os.path.join(root, 'JPEGImages')
        mask_file = os.path.join(root, self.split+'.pth')
        # Detail API handle over this split's annotations
        self.detail = Detail(annFile, imgDir, split)
        self.transform = transform
        self.target_transform = target_transform
        self.ids = self.detail.getImgs()  # image records: dicts with 'file_name' and 'image_id'
        # build the label lookup tables and the per-image mask cache
        # the 60 raw Detail label ids kept in this benchmark (0 is background),
        # sorted so np.digitize can locate them by binary search
        self._mapping = np.sort(np.array([
            0, 2, 259, 260, 415, 324, 9, 258, 144, 18, 19, 22,
            23, 397, 25, 284, 158, 159, 416, 33, 162, 420, 454, 295, 296,
            427, 44, 45, 46, 308, 59, 440, 445, 31, 232, 65, 354, 424,
            68, 326, 72, 458, 34, 207, 80, 355, 85, 347, 220, 349, 360,
            98, 187, 104, 105, 366, 189, 368, 113, 115]))
        # _key[i] is the contiguous train id assigned to raw id _mapping[i]
        self._key = np.arange(len(self._mapping)).astype('uint8')
        # converting masks is slow, so cache them per split in a .pth file
        if os.path.exists(mask_file):
            self.masks = torch.load(mask_file)
        else:
            self.masks = self._preprocess(mask_file)

    def _class_to_index(self, mask):
        # every raw label id in the mask must be one of the 60 kept ids
        values = np.unique(mask)
        for value in values:
            assert value in self._mapping
        # binary-search each pixel's raw id in _mapping, then remap through _key
        index = np.digitize(mask.ravel(), self._mapping, right=True)
        return self._key[index].reshape(mask.shape)
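
    # Note (illustrative, not from the original source): the digitize trick
    # above works because _mapping is sorted and every pixel value is asserted
    # to be present in it, so np.digitize(..., right=True) returns each raw
    # id's exact position. E.g. the sorted _mapping begins [0, 2, 9, 18, ...],
    # so a pixel holding raw label 9 digitizes to position 2 and _key[2] == 2
    # becomes its contiguous train id.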

    def _preprocess(self, mask_file):
        masks = {}
        print("Preprocessing masks, this will take a while. " + \
              "But don't worry, it only runs once for each split.")
        tbar = trange(len(self.ids))
        for i in tbar:
            img_id = self.ids[i]
            mask = Image.fromarray(self._class_to_index(
                self.detail.getMask(img_id)))
            masks[img_id['image_id']] = mask
            tbar.set_description("Preprocessing masks {}".format(img_id['image_id']))
        torch.save(masks, mask_file)
        return masks

    def __getitem__(self, index):
        img_id = self.ids[index]
        path = img_id['file_name']
        iid = img_id['image_id']
        img = Image.open(os.path.join(self.detail.img_folder, path)).convert('RGB')
        if self.mode == 'test':
            if self.transform is not None:
                img = self.transform(img)
            return img, os.path.basename(path)
        # the mask was already converted to 60-category indexing by _preprocess
        mask = self.masks[iid]
        # synchronized transform on image and mask
        if self.mode == 'train':
            img, mask = self._sync_transform(img, mask)
        elif self.mode == 'val':
            img, mask = self._val_sync_transform(img, mask)
        else:
            assert self.mode == 'testval'
            mask = self._mask_transform(mask)
        # general resize, normalize and toTensor
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            mask = self.target_transform(mask)
        return img, mask

    def _mask_transform(self, mask):
        # shift labels down by 1: the 59 classes become 0..58 and background
        # becomes -1; pred_offset adds the 1 back at prediction time
        target = np.array(mask).astype('int32') - 1
        return torch.from_numpy(target).long()

    def __len__(self):
        return len(self.ids)

    @property
    def pred_offset(self):
        return 1
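
# Minimal usage sketch (illustrative, not part of the original file). Assumes
# the Detail API is installed and ~/.encoding/data/VOCdevkit/VOC2010 contains
# the JPEGImages folder plus trainval_merged.json; the normalization values
# are the usual ImageNet statistics, not something this file prescribes:
#
#   import torchvision.transforms as transforms
#   input_transform = transforms.Compose([
#       transforms.ToTensor(),
#       transforms.Normalize([.485, .456, .406], [.229, .224, .225])])
#   trainset = ContextSegmentation(split='train', mode='train',
#                                  transform=input_transform)
#   img, mask = trainset[0]   # img: 3xHxW float tensor, mask: HxW long tensor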