import os
import json
import tqdm
import numpy as np
import dgl
from zipfile import ZipFile
from torch.utils.data import Dataset
from scipy.sparse import csr_matrix
from dgl.data.utils import download, get_download_dir

class ShapeNet(object):
    def __init__(self, num_points=2048, normal_channel=True):
        self.num_points = num_points
        self.normal_channel = normal_channel

        SHAPENET_DOWNLOAD_URL = "https://shapenet.cs.stanford.edu/media/shapenetcore_partanno_segmentation_benchmark_v0_normal.zip"
        download_path = get_download_dir()
        data_filename = "shapenetcore_partanno_segmentation_benchmark_v0_normal.zip"
        data_path = os.path.join(download_path, "shapenetcore_partanno_segmentation_benchmark_v0_normal")
        if not os.path.exists(data_path):
            local_path = os.path.join(download_path, data_filename)
            if not os.path.exists(local_path):
                download(SHAPENET_DOWNLOAD_URL, local_path, verify_ssl=False)
            with ZipFile(local_path) as z:
                z.extractall(path=download_path)

        synset_file = "synsetoffset2category.txt"
        with open(os.path.join(data_path, synset_file)) as f:
            synset = [t.split('\n')[0].split('\t') for t in f.readlines()]
        self.synset_dict = {}
        for syn in synset:
            self.synset_dict[syn[1]] = syn[0]
        self.seg_classes = {'Airplane': [0, 1, 2, 3],
                            'Bag': [4, 5],
                            'Cap': [6, 7],
                            'Car': [8, 9, 10, 11],
                            'Chair': [12, 13, 14, 15],
                            'Earphone': [16, 17, 18],
                            'Guitar': [19, 20, 21],
                            'Knife': [22, 23],
                            'Lamp': [24, 25, 26, 27],
                            'Laptop': [28, 29],
                            'Motorbike': [30, 31, 32, 33, 34, 35],
                            'Mug': [36, 37],
                            'Pistol': [38, 39, 40],
                            'Rocket': [41, 42, 43],
                            'Skateboard': [44, 45, 46],
                            'Table': [47, 48, 49]}

        train_split_json = 'shuffled_train_file_list.json'
        val_split_json = 'shuffled_val_file_list.json'
        test_split_json = 'shuffled_test_file_list.json'
        split_path = os.path.join(data_path, 'train_test_split')
        with open(os.path.join(split_path, train_split_json)) as f:
            tmp = f.read()
            self.train_file_list = [os.path.join(data_path, t.replace('shape_data/', '') + '.txt') for t in json.loads(tmp)]
        with open(os.path.join(split_path, val_split_json)) as f:
            tmp = f.read()
            self.val_file_list = [os.path.join(data_path, t.replace('shape_data/', '') + '.txt') for t in json.loads(tmp)]
        with open(os.path.join(split_path, test_split_json)) as f:
            tmp = f.read()
            self.test_file_list = [os.path.join(data_path, t.replace('shape_data/', '') + '.txt') for t in json.loads(tmp)]

    def train(self):
        return ShapeNetDataset(self, 'train', self.num_points, self.normal_channel)

    def valid(self):
        return ShapeNetDataset(self, 'valid', self.num_points, self.normal_channel)

    def trainval(self):
        return ShapeNetDataset(self, 'trainval', self.num_points, self.normal_channel)

    def test(self):
        return ShapeNetDataset(self, 'test', self.num_points, self.normal_channel)
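
# Note: each split helper above returns a ShapeNetDataset (defined below),
# e.g. `ShapeNet(num_points=2048, normal_channel=False).train()` gives a
# torch Dataset yielding (points, per-point labels, category name) samples.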

class ShapeNetDataset(Dataset):
    def __init__(self, shapenet, mode, num_points, normal_channel=True):
        super(ShapeNetDataset, self).__init__()
        self.mode = mode
        self.num_points = num_points
        if not normal_channel:
            self.dim = 3
        else:
            self.dim = 6

        if mode == 'train':
            self.file_list = shapenet.train_file_list
        elif mode == 'valid':
            self.file_list = shapenet.val_file_list
        elif mode == 'test':
            self.file_list = shapenet.test_file_list
        elif mode == 'trainval':
            self.file_list = shapenet.train_file_list + shapenet.val_file_list
        else:
            raise "Not supported `mode`"

        data_list = []
        label_list = []
        category_list = []
        print('Loading data from split ' + self.mode)
        for fn in tqdm.tqdm(self.file_list, ascii=True):
            with open(fn) as f:
                data = np.array([t.split('\n')[0].split(' ') for t in f.readlines()]).astype(float)
            data_list.append(data[:, 0:self.dim])
            label_list.append(data[:, 6].astype(int))
            category_list.append(shapenet.synset_dict[os.path.basename(os.path.dirname(fn))])
        self.data = data_list
        self.label = label_list
        self.category = category_list

    def translate(self, x, scale=(2/3, 3/2), shift=(-0.2, 0.2), size=3):
        # Random anisotropic scale-and-shift augmentation: each of the
        # `size` channels gets an independent scale factor and offset.
        xyz1 = np.random.uniform(low=scale[0], high=scale[1], size=[size])
        xyz2 = np.random.uniform(low=shift[0], high=shift[1], size=[size])
        x = np.add(np.multiply(x, xyz1), xyz2).astype('float32')
        return x

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        inds = np.random.choice(self.data[i].shape[0], self.num_points, replace=True)
        x = self.data[i][inds,:self.dim]
        y = self.label[i][inds]
        cat = self.category[i]
        if self.mode == 'train':
            x = self.translate(x, size=self.dim)
        x = x.astype(float)
        y = y.astype(int)
        return x, y, cat
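

if __name__ == '__main__':
    # Minimal usage sketch: build the wrapper (this triggers the download and
    # extraction on first use), take the training split, and batch it with a
    # standard PyTorch DataLoader. Batch size and worker count are arbitrary.
    from torch.utils.data import DataLoader

    shapenet = ShapeNet(num_points=2048, normal_channel=True)
    train_set = shapenet.train()
    loader = DataLoader(train_set, batch_size=8, shuffle=True, num_workers=0)
    for points, labels, categories in loader:
        # points: (8, 2048, 6) tensor, labels: (8, 2048) integer tensor,
        # categories: list of 8 category-name strings
        print(points.shape, labels.shape, categories[0])
        break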