# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
"""
Sanity checks for loading ShapeNet Core v1.
"""
import os
import random
import unittest
import warnings
from pathlib import Path

import numpy as np
import torch
from common_testing import TestCaseMixin, load_rgb_image
from PIL import Image
from pytorch3d.datasets import ShapeNetCore
from pytorch3d.renderer import (
    OpenGLPerspectiveCameras,
    PointLights,
    RasterizationSettings,
    look_at_view_transform,
)


# Path to a local copy of the ShapeNet Core dataset. The dataset is not
# distributed with this repo; set this to the download location before
# running the tests (tests are skipped while it is None / missing).
SHAPENET_PATH = None

# If DEBUG=True, save out images generated in the tests for debugging.
# All saved images have prefix DEBUG_
DEBUG = False
DATA_DIR = Path(__file__).resolve().parent / "data"
class TestShapenetCore(TestCaseMixin, unittest.TestCase):
    def test_load_shapenet_core(self):
        """
        End-to-end sanity check for ShapeNetCore v1: invalid-version error,
        dataset length, __getitem__ contents, category subsetting, and a
        reference render of the first piano model.

        Skipped (not silently passed) when the dataset is not available.
        """
        # Setup: rendering in this test runs on GPU.
        device = torch.device("cuda:0")

        # The ShapeNet dataset is not provided in the repo.
        # Download this separately and update the `shapenet_path`
        # with the location of the dataset in order to run this test.
        if SHAPENET_PATH is None or not os.path.exists(SHAPENET_PATH):
            url = "https://www.shapenet.org/"
            msg = """ShapeNet data not found, download from %s, save it at the path %s,
                update SHAPENET_PATH at the top of the file, and rerun""" % (
                url,
                SHAPENET_PATH,
            )
            warnings.warn(msg)
            # FIX: the original did `return True` here. Returning a non-None
            # value from a unittest test method raises a DeprecationWarning on
            # Python 3.11+ and, worse, reports the test as *passed* even though
            # nothing was checked. Skipping makes the outcome visible.
            self.skipTest(msg)

        # Try loading ShapeNetCore with an invalid version number and catch error.
        with self.assertRaises(ValueError) as err:
            ShapeNetCore(SHAPENET_PATH, version=3)
        self.assertTrue("Version number must be either 1 or 2." in str(err.exception))

        # Load ShapeNetCore without specifying any particular categories.
        shapenet_dataset = ShapeNetCore(SHAPENET_PATH)

        # Count the number of grandchildren directories (which should be equal to
        # the total number of objects in the dataset) by walking through the given
        # directory.
        wnsynset_list = [
            wnsynset
            for wnsynset in os.listdir(SHAPENET_PATH)
            if os.path.isdir(os.path.join(SHAPENET_PATH, wnsynset))
        ]
        model_num_list = [
            (len(next(os.walk(os.path.join(SHAPENET_PATH, wnsynset)))[1]))
            for wnsynset in wnsynset_list
        ]
        # Check total number of objects in the dataset is correct.
        self.assertEqual(len(shapenet_dataset), sum(model_num_list))

        # Randomly retrieve an object from the dataset.
        rand_obj = random.choice(shapenet_dataset)
        self.assertEqual(len(rand_obj), 5)

        # Check that data types and shapes of items returned by __getitem__ are correct.
        verts, faces = rand_obj["verts"], rand_obj["faces"]
        self.assertTrue(verts.dtype == torch.float32)
        self.assertTrue(faces.dtype == torch.int64)
        self.assertEqual(verts.ndim, 2)
        self.assertEqual(verts.shape[-1], 3)
        self.assertEqual(faces.ndim, 2)
        self.assertEqual(faces.shape[-1], 3)

        # Load six categories from ShapeNetCore.
        # Specify categories in the form of a combination of offsets and labels.
        shapenet_subset = ShapeNetCore(
            SHAPENET_PATH,
            synsets=[
                "04330267",
                "guitar",
                "02801938",
                "birdhouse",
                "03991062",
                "tower",
            ],
            version=1,
        )
        # The offsets corresponding to the six requested categories above.
        subset_offsets = [
            "04330267",
            "03467517",
            "02801938",
            "02843684",
            "03991062",
            "04460130",
        ]
        subset_model_nums = [
            (len(next(os.walk(os.path.join(SHAPENET_PATH, offset)))[1]))
            for offset in subset_offsets
        ]
        self.assertEqual(len(shapenet_subset), sum(subset_model_nums))

        # Render the first image in the piano category.
        R, T = look_at_view_transform(1.0, 1.0, 90)
        piano_dataset = ShapeNetCore(SHAPENET_PATH, synsets=["piano"])

        cameras = OpenGLPerspectiveCameras(R=R, T=T, device=device)
        raster_settings = RasterizationSettings(image_size=512)
        lights = PointLights(
            location=torch.tensor([0.0, 1.0, -2.0], device=device)[None],
            # TODO: debug the source of the discrepancy in two images when rendering on GPU.
            diffuse_color=((0, 0, 0),),
            specular_color=((0, 0, 0),),
            device=device,
        )
        images = piano_dataset.render(
            0,
            device=device,
            cameras=cameras,
            raster_settings=raster_settings,
            lights=lights,
        )
        rgb = images[0, ..., :3].squeeze().cpu()
        if DEBUG:
            Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
                DATA_DIR / "DEBUG_shapenet_core_render_piano.png"
            )
        # Compare against the checked-in reference render (loose tolerance to
        # absorb rasterization differences across GPUs/drivers).
        image_ref = load_rgb_image("test_shapenet_core_render_piano.png", DATA_DIR)
        self.assertClose(rgb, image_ref, atol=0.05)