ShapeNet.py 5.59 KB
Newer Older
1
2
import json
import os
esang's avatar
esang committed
3
from zipfile import ZipFile
4

Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
5
6
import dgl

7
8
import numpy as np
import tqdm
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
9
from dgl.data.utils import download, get_download_dir
esang's avatar
esang committed
10
from scipy.sparse import csr_matrix
11
12
13
from torch.utils.data import Dataset


esang's avatar
esang committed
14
15
16
17
18
19
20
class ShapeNet(object):
    """Index of the ShapeNet part-annotation benchmark (with normals).

    On construction the benchmark archive is downloaded and extracted if
    its folder is not already present, then the category table and the
    train/val/test file lists are read into memory.  No point cloud is
    loaded here; use :meth:`train`, :meth:`valid`, :meth:`trainval` or
    :meth:`test` to obtain a dataset object that loads the points.

    Parameters
    ----------
    num_points : int
        Number of points each dataset sample is resampled to.
    normal_channel : bool
        Forwarded to the dataset objects; selects 6 feature columns
        instead of 3.
    download_dir : str, optional
        Directory that holds (or will receive) the archive and the
        extracted folder.  Defaults to ``get_download_dir()``.

    Attributes
    ----------
    synset_dict : dict
        Maps the second tab-separated field of each line of
        ``synsetoffset2category.txt`` to the first one (per the file
        name, presumably synset offset -> category name).
    seg_classes : dict
        Maps a category name to the list of part label ids it owns.
    train_file_list, val_file_list, test_file_list : list of str
        Paths of the per-shape ``.txt`` files of each split.
    """

    _URL = "https://shapenet.cs.stanford.edu/media/shapenetcore_partanno_segmentation_benchmark_v0_normal.zip"
    _ARCHIVE_NAME = "shapenetcore_partanno_segmentation_benchmark_v0_normal.zip"
    _FOLDER_NAME = "shapenetcore_partanno_segmentation_benchmark_v0_normal"

    def __init__(self, num_points=2048, normal_channel=True, download_dir=None):
        self.num_points = num_points
        self.normal_channel = normal_channel

        if download_dir is None:
            download_dir = get_download_dir()
        data_path = os.path.join(download_dir, self._FOLDER_NAME)

        if not os.path.exists(data_path):
            local_path = os.path.join(download_dir, self._ARCHIVE_NAME)
            if not os.path.exists(local_path):
                # NOTE(review): certificate verification is disabled for this
                # download, so the archive is not authenticated. Confirm this
                # is still required for the host before relying on it.
                download(self._URL, local_path, verify_ssl=False)
            with ZipFile(local_path) as archive:
                archive.extractall(path=download_dir)

        self.synset_dict = self._read_synsets(
            os.path.join(data_path, "synsetoffset2category.txt")
        )

        self.seg_classes = {
            "Airplane": [0, 1, 2, 3],
            "Bag": [4, 5],
            "Cap": [6, 7],
            "Car": [8, 9, 10, 11],
            "Chair": [12, 13, 14, 15],
            "Earphone": [16, 17, 18],
            "Guitar": [19, 20, 21],
            "Knife": [22, 23],
            "Lamp": [24, 25, 26, 27],
            "Laptop": [28, 29],
            "Motorbike": [30, 31, 32, 33, 34, 35],
            "Mug": [36, 37],
            "Pistol": [38, 39, 40],
            "Rocket": [41, 42, 43],
            "Skateboard": [44, 45, 46],
            "Table": [47, 48, 49],
        }

        split_path = os.path.join(data_path, "train_test_split")
        self.train_file_list = self._read_split(
            data_path, split_path, "shuffled_train_file_list.json"
        )
        self.val_file_list = self._read_split(
            data_path, split_path, "shuffled_val_file_list.json"
        )
        self.test_file_list = self._read_split(
            data_path, split_path, "shuffled_test_file_list.json"
        )

    @staticmethod
    def _read_synsets(path):
        """Parse the tab-separated category table into ``{field2: field1}``."""
        table = {}
        with open(path) as f:
            for line in f:
                fields = line.split("\n")[0].split("\t")
                # Blank or malformed lines (e.g. a trailing empty line) carry
                # no mapping; skip them instead of failing with IndexError.
                if len(fields) < 2:
                    continue
                table[fields[1]] = fields[0]
        return table

    @staticmethod
    def _read_split(data_path, split_path, split_json):
        """Turn one split JSON file into a list of point-cloud file paths.

        Each JSON entry has its ``shape_data/`` prefix removed, gets a
        ``.txt`` suffix and is joined onto ``data_path``.
        """
        with open(os.path.join(split_path, split_json)) as f:
            entries = json.load(f)
        return [
            os.path.join(data_path, entry.replace("shape_data/", "") + ".txt")
            for entry in entries
        ]

    def _dataset(self, mode):
        """Build the dataset object for ``mode`` with this index's settings."""
        return ShapeNetDataset(
            self, mode, self.num_points, self.normal_channel
        )

    def train(self):
        """Return the dataset for the training split."""
        return self._dataset("train")

    def valid(self):
        """Return the dataset for the validation split."""
        return self._dataset("valid")

    def trainval(self):
        """Return the dataset for the training and validation splits combined."""
        return self._dataset("trainval")

    def test(self):
        """Return the dataset for the test split."""
        return self._dataset("test")

esang's avatar
esang committed
103
104
105
106
107
108
109
110
111
112
113

class ShapeNetDataset(Dataset):
    """One split of the ShapeNet part benchmark, fully loaded into memory.

    Every file of the selected split is parsed at construction time.  Each
    item is a random resample (with replacement) of ``num_points`` points
    from one shape.

    Parameters
    ----------
    shapenet : object
        Index object providing ``train_file_list``, ``val_file_list``,
        ``test_file_list`` and ``synset_dict``.
    mode : str
        One of ``"train"``, ``"valid"``, ``"test"`` or ``"trainval"``.
    num_points : int
        Number of points drawn per sample.
    normal_channel : bool
        If true the first 6 columns of every file are kept as features,
        otherwise only the first 3.

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported values.
    """

    # Column of the per-point part label in the raw text files.
    _LABEL_COLUMN = 6

    def __init__(self, shapenet, mode, num_points, normal_channel=True):
        super(ShapeNetDataset, self).__init__()
        self.mode = mode
        self.num_points = num_points
        self.dim = 6 if normal_channel else 3

        if mode == "train":
            self.file_list = shapenet.train_file_list
        elif mode == "valid":
            self.file_list = shapenet.val_file_list
        elif mode == "test":
            self.file_list = shapenet.test_file_list
        elif mode == "trainval":
            self.file_list = shapenet.train_file_list + shapenet.val_file_list
        else:
            # Raising a plain string is itself a TypeError in Python 3.
            raise ValueError("Not supported `mode`: {!r}".format(mode))

        data_list = []
        label_list = []
        category_list = []
        print("Loading data from split " + self.mode)
        for fn in tqdm.tqdm(self.file_list, ascii=True):
            with open(fn) as f:
                # np.float64 is the dtype the removed ``np.float`` alias
                # referred to.
                data = np.array(
                    [t.split("\n")[0].split(" ") for t in f.readlines()]
                ).astype(np.float64)
            data_list.append(data[:, 0 : self.dim])
            label_list.append(data[:, self._LABEL_COLUMN].astype(int))
            # The parent folder name keys the category table. os.path is
            # used so this also works when the path contains backslashes.
            folder = os.path.basename(os.path.dirname(fn))
            category_list.append(shapenet.synset_dict[folder])
        self.data = data_list
        self.label = label_list
        self.category = category_list

    def translate(self, x, scale=(2 / 3, 3 / 2), shift=(-0.2, 0.2), size=3):
        """Apply a random per-column scale and shift to ``x``.

        One scale factor and one offset are drawn uniformly per column
        (``size`` columns) and broadcast over all rows.

        Returns
        -------
        numpy.ndarray
            ``x * scale + shift`` as ``float32``.
        """
        factors = np.random.uniform(low=scale[0], high=scale[1], size=[size])
        offsets = np.random.uniform(low=shift[0], high=shift[1], size=[size])
        return np.add(np.multiply(x, factors), offsets).astype("float32")

    def __len__(self):
        """Return the number of shapes in this split."""
        return len(self.data)

    def __getitem__(self, i):
        """Return ``(points, labels, category)`` for shape ``i``.

        ``num_points`` rows are sampled with replacement; the same indices
        select the labels so points and labels stay aligned.  In
        ``"train"`` mode the points additionally go through
        :meth:`translate`.  Points are returned as ``float64`` and labels
        as ``int``.
        """
        inds = np.random.choice(
            self.data[i].shape[0], self.num_points, replace=True
        )
        # self.data[i] was already cut down to ``self.dim`` columns on load.
        x = self.data[i][inds]
        y = self.label[i][inds]
        cat = self.category[i]
        if self.mode == "train":
            x = self.translate(x, size=self.dim)
        x = x.astype(np.float64)
        y = y.astype(int)
        return x, y, cat