# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
"""
Sanity checks for loading ShapeNet Core v1.
"""
import os
import random
import unittest
import warnings

import torch
from common_testing import TestCaseMixin
from pytorch3d.datasets import ShapeNetCore


# Filesystem location of a local ShapeNet Core copy. The dataset is not
# provided in the repo, so this defaults to None; download the data separately
# and set this path by hand in order to run the test below.
SHAPENET_PATH = None


class TestShapenetCore(TestCaseMixin, unittest.TestCase):
    """Sanity checks for loading ShapeNet Core v1 from a local copy on disk."""

    @staticmethod
    def _count_model_dirs(synset_dir):
        """Return the number of immediate subdirectories of ``synset_dir``.

        The test treats every subdirectory of a synset directory as one
        object of the dataset.
        """
        with os.scandir(synset_dir) as entries:
            return sum(1 for entry in entries if entry.is_dir())

    def test_load_shapenet_core(self):
        """Load ShapeNetCore and check its size and the items it returns.

        The test is skipped when ``SHAPENET_PATH`` is unset or does not exist,
        because the ShapeNet dataset is not provided in the repo. Otherwise it
        checks that:

        * an invalid version number raises ``ValueError``;
        * the length of the full dataset equals the number of grandchildren
          directories of ``SHAPENET_PATH``;
        * a randomly chosen item has 5 entries, with float32 ``verts`` and
          int64 ``faces`` that are both 2D with a last dimension of 3;
        * a subset selected by a mix of synset offsets and labels has the
          expected number of objects.
        """
        # The ShapeNet dataset is not provided in the repo.
        # Download this separately and update `SHAPENET_PATH` at the top of
        # the file with the location of the dataset in order to run this test.
        if SHAPENET_PATH is None or not os.path.exists(SHAPENET_PATH):
            url = "https://www.shapenet.org/"
            msg = (
                "ShapeNet data not found, download from %s, update "
                "SHAPENET_PATH at the top of the file, and rerun."
            )
            # Report the test as skipped instead of returning a value: a test
            # that returns early is counted as passed even though nothing ran.
            self.skipTest(msg % url)

        # Try load ShapeNetCore with an invalid version number and catch error.
        with self.assertRaises(ValueError) as err:
            ShapeNetCore(SHAPENET_PATH, version=3)
        self.assertIn("Version number must be either 1 or 2.", str(err.exception))

        # Load ShapeNetCore without specifying any particular categories.
        shapenet_dataset = ShapeNetCore(SHAPENET_PATH)

        # Count the number of grandchildren directories (which should be equal
        # to the total number of objects in the dataset): one level for the
        # synset directories, one level for the objects inside each of them.
        synset_dirs = [
            os.path.join(SHAPENET_PATH, wnsynset)
            for wnsynset in os.listdir(SHAPENET_PATH)
            if os.path.isdir(os.path.join(SHAPENET_PATH, wnsynset))
        ]
        model_num_list = [
            self._count_model_dirs(synset_dir) for synset_dir in synset_dirs
        ]
        # Check total number of objects in the dataset is correct.
        self.assertEqual(len(shapenet_dataset), sum(model_num_list))

        # Randomly retrieve an object from the dataset.
        rand_obj = random.choice(shapenet_dataset)
        self.assertEqual(len(rand_obj), 5)
        # Check that data types and shapes of items returned by __getitem__
        # are correct.
        verts, faces = rand_obj["verts"], rand_obj["faces"]
        self.assertEqual(verts.dtype, torch.float32)
        self.assertEqual(faces.dtype, torch.int64)
        self.assertEqual(verts.ndim, 2)
        self.assertEqual(verts.shape[-1], 3)
        self.assertEqual(faces.ndim, 2)
        self.assertEqual(faces.shape[-1], 3)

        # Load six categories from ShapeNetCore.
        # Specify categories in the form of a combination of offsets and labels.
        shapenet_subset = ShapeNetCore(
            SHAPENET_PATH,
            synsets=[
                "04330267",
                "guitar",
                "02801938",
                "birdhouse",
                "03991062",
                "tower",
            ],
            version=1,
        )
        # Offsets of the directories expected to make up the subset above.
        subset_offsets = [
            "04330267",
            "03467517",
            "02801938",
            "02843684",
            "03991062",
            "04460130",
        ]
        subset_model_nums = [
            self._count_model_dirs(os.path.join(SHAPENET_PATH, offset))
            for offset in subset_offsets
        ]
        self.assertEqual(len(shapenet_subset), sum(subset_model_nums))