"vscode:/vscode.git/clone" did not exist on "1177a80b591294e235966bc3ebc701948e505e7a"
ModelNetDataLoader.py 3.85 KB
Newer Older
esang's avatar
esang committed
1
import os
2
3
4
import warnings

import numpy as np
esang's avatar
esang committed
5
from torch.utils.data import Dataset
6
7

warnings.filterwarnings("ignore")
esang's avatar
esang committed
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47


def pc_normalize(pc):
    """Center a point cloud at the origin and scale it into the unit ball.

    Args:
        pc: array of shape [N, C]. Every column is centered and scaled (the
            dataset below passes only the xyz columns).

    Returns:
        A new array of the same shape. Its column means are zero, and its
        farthest row lies at Euclidean distance 1 from the origin. If all rows
        coincide (maximum radius is 0), the centered array is returned
        unscaled instead of dividing by zero and producing NaNs.
    """
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    # Radius of the smallest origin-centred ball containing every point.
    m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
    if m > 0:
        pc = pc / m
    return pc


def farthest_point_sample(point, npoint):
    """Down-sample a point cloud with iterative farthest point sampling.

    Farthest point sampling works as follows:
    1. Initialize the sample set S with a random point.
    2. Pick the point P not in S that maximizes the distance d(P, S).
    3. Repeat step 2 until |S| = npoint.

    Only the first three columns (xyz) are used to measure distances. All
    columns of the selected rows are returned.

    Args:
        point: point cloud array of shape [N, D] with D >= 3.
        npoint: number of points to sample.

    Returns:
        The sampled points (the selected rows of ``point``, not their
        indices), an array of shape [npoint, D]. If npoint > N, some rows are
        necessarily repeated.
    """
    n_points = point.shape[0]
    xyz = point[:, :3]
    # Hold the selected row indices as integers directly.
    centroids = np.zeros((npoint,), dtype=np.int64)
    # Squared distance from each point to its nearest already-selected point.
    distance = np.ones((n_points,)) * 1e10
    farthest = np.random.randint(0, n_points)
    for i in range(npoint):
        centroids[i] = farthest
        centroid = xyz[farthest, :]
        dist = np.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = np.argmax(distance, -1)
    return point[centroids]


class ModelNetDataLoader(Dataset):
    """ModelNet40 point-cloud dataset backed by per-shape text files.

    Sample ``i`` is read from ``<root>/<shape_name>/<shape_id>.txt`` as
    comma-separated floats. It is then reduced to ``npoint`` rows (by
    truncation, or by farthest point sampling when ``fps`` is set). Its xyz
    columns are normalized with ``pc_normalize``, and it is optionally reduced
    to xyz only. Up to ``cache_size`` processed samples are kept in
    ``self.cache``.

    NOTE(review): cached arrays are returned by reference. A caller that
    modifies a returned sample in place also modifies the cached copy.
    """

    def __init__(
        self,
        root,
        npoint=1024,
        split="train",
        fps=False,
        normal_channel=True,
        cache_size=15000,
    ):
        """
        Input:
            root: the root path to the local data files
            npoint: number of points from each cloud
            split: which split of the data, 'train' or 'test'
            fps: whether to sample points with farthest point sampler
            normal_channel: whether to keep the columns after xyz
                (presumably the normals); if False only xyz is kept
            cache_size: maximum number of processed point clouds kept in memory
        Raises:
            ValueError: if split is neither 'train' nor 'test'.
        """
        # Validate with an exception rather than `assert`, which -O strips.
        if split not in ("train", "test"):
            raise ValueError("split must be 'train' or 'test', got %r" % (split,))

        self.root = root
        self.npoints = npoint
        self.fps = fps
        self.catfile = os.path.join(self.root, "modelnet40_shape_names.txt")

        self.cat = self._read_lines(self.catfile)
        # Class label = line position in the shape-names file.
        self.classes = dict(zip(self.cat, range(len(self.cat))))
        self.normal_channel = normal_channel

        # Only the requested split's id list is needed.
        shape_ids = self._read_lines(
            os.path.join(self.root, "modelnet40_%s.txt" % split)
        )
        # The shape name is the id minus its trailing "_<number>" suffix,
        # e.g. "night_stand_0002" -> "night_stand".
        shape_names = ["_".join(x.split("_")[0:-1]) for x in shape_ids]
        # list of (shape_name, shape_txt_file_path) tuple
        self.datapath = [
            (name, os.path.join(self.root, name, shape_id) + ".txt")
            for name, shape_id in zip(shape_names, shape_ids)
        ]
        print("The size of %s data is %d" % (split, len(self.datapath)))

        self.cache_size = cache_size
        self.cache = {}

    @staticmethod
    def _read_lines(path):
        """Return the lines of the text file at ``path`` with trailing whitespace stripped."""
        with open(path) as f:
            return [line.rstrip() for line in f]

    def __len__(self):
        return len(self.datapath)

    def _get_item(self, index):
        """Load, process and (if room remains) cache sample ``index``.

        Returns:
            (point_set, cls): point_set is a float32 array of at most
            ``npoints`` rows; cls is a length-1 int32 array holding the label.
        """
        if index in self.cache:
            point_set, cls = self.cache[index]
            return point_set, cls

        shape_name, path = self.datapath[index]
        cls = np.array([self.classes[shape_name]]).astype(np.int32)
        # ndmin=2 keeps a single-line file two-dimensional so row/column
        # slicing below still works.
        point_set = np.loadtxt(path, delimiter=",", ndmin=2).astype(np.float32)

        if self.fps:
            point_set = farthest_point_sample(point_set, self.npoints)
        else:
            # Copy so the cache does not keep the whole loaded file alive
            # through a view of its first npoints rows.
            point_set = point_set[0 : self.npoints, :].copy()

        point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])

        if not self.normal_channel:
            # Copy for the same reason: drop the extra columns' storage.
            point_set = point_set[:, 0:3].copy()

        if len(self.cache) < self.cache_size:
            self.cache[index] = (point_set, cls)

        return point_set, cls

    def __getitem__(self, index):
        return self._get_item(index)