test_rendering_meshes.py 13 KB
Newer Older
facebook-github-bot's avatar
facebook-github-bot committed
1
2
3
4
5
6
7
8
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.


"""
Sanity checks for output images from the renderer.
"""
import unittest
from pathlib import Path
9
10

import numpy as np
facebook-github-bot's avatar
facebook-github-bot committed
11
12
import torch
from PIL import Image
13
from pytorch3d.io import load_objs_as_meshes
14
from pytorch3d.renderer.cameras import OpenGLPerspectiveCameras, look_at_view_transform
facebook-github-bot's avatar
facebook-github-bot committed
15
16
from pytorch3d.renderer.lighting import PointLights
from pytorch3d.renderer.materials import Materials
17
from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings
facebook-github-bot's avatar
facebook-github-bot committed
18
19
20
from pytorch3d.renderer.mesh.renderer import MeshRenderer
from pytorch3d.renderer.mesh.shader import (
    BlendParams,
21
    HardFlatShader,
Patrick Labatut's avatar
Patrick Labatut committed
22
    HardGouraudShader,
23
24
25
    HardPhongShader,
    SoftSilhouetteShader,
    TexturedSoftPhongShader,
facebook-github-bot's avatar
facebook-github-bot committed
26
27
28
29
30
)
from pytorch3d.renderer.mesh.texturing import Textures
from pytorch3d.structures.meshes import Meshes
from pytorch3d.utils.ico_sphere import ico_sphere

31

Nikhila Ravi's avatar
Nikhila Ravi committed
32
# If DEBUG=True, save out images generated in the tests for debugging.
facebook-github-bot's avatar
facebook-github-bot committed
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# Rendered images written out for debugging carry a "DEBUG_" filename prefix
# and are saved into DATA_DIR (only when DEBUG is flipped to True).
DEBUG = False
# Directory holding the reference images that the tests compare against.
DATA_DIR = Path(__file__).resolve().parent / "data"


def load_rgb_image(filename, data_dir=DATA_DIR):
    """
    Load ``data_dir / filename`` and return its RGB channels as a float32
    tensor with values scaled from [0, 255] down to [0, 1].
    """
    path = data_dir / filename
    with Image.open(path) as raw_image:
        # Force pixel decode while the file handle is still open.
        pixels = np.array(raw_image) / 255.0
    tensor = torch.from_numpy(pixels).to(dtype=torch.float32)
    # Drop any alpha channel; keep only the first three (RGB) channels.
    return tensor[..., :3]


class TestRenderingMeshes(unittest.TestCase):
    def test_simple_sphere(self, elevated_camera=False):
        """
        Test output of phong, gouraud and flat shading matches a reference
        image using the default values for the light sources.

        Args:
            elevated_camera: Defines whether the camera observing the scene should
                           have an elevation of 45 degrees.
        """
        device = torch.device("cuda:0")

        # Init mesh with a constant white vertex texture.
        sphere_mesh = ico_sphere(5, device)
        verts_padded = sphere_mesh.verts_padded()
        faces_padded = sphere_mesh.faces_padded()
        textures = Textures(verts_rgb=torch.ones_like(verts_padded))
        sphere_mesh = Meshes(verts=verts_padded, faces=faces_padded, textures=textures)

        # Init rasterizer settings
        if elevated_camera:
            # Elevated and rotated camera
            R, T = look_at_view_transform(dist=2.7, elev=45.0, azim=45.0)
            postfix = "_elevated_camera"
            # If y axis is up, the spot of light should
            # be on the bottom left of the sphere.
        else:
            # No elevation or azimuth rotation
            R, T = look_at_view_transform(2.7, 0.0, 0.0)
            postfix = ""
        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)

        # Init shader settings
        materials = Materials(device=device)
        lights = PointLights(device=device)
        lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]

        raster_settings = RasterizationSettings(
            image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0
        )
        rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)

        # Test several shaders
        shaders = {
            "phong": HardPhongShader,
            "gouraud": HardGouraudShader,
            "flat": HardFlatShader,
        }
        for (name, shader_init) in shaders.items():
            shader = shader_init(lights=lights, cameras=cameras, materials=materials)
            renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
            images = renderer(sphere_mesh)
            filename = "simple_sphere_light_%s%s.png" % (name, postfix)
            image_ref = load_rgb_image("test_%s" % filename)
            rgb = images[0, ..., :3].squeeze().cpu()
            if DEBUG:
                # BUG FIX: '"DEBUG_" % filename' raised TypeError (the format
                # string has no conversion specifier); use %s formatting to
                # build the debug filename.
                filename = "DEBUG_%s" % filename
                Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
                    DATA_DIR / filename
                )
            self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))

        ########################################################
        # Move the light to the +z axis in world space so it is
        # behind the sphere. Note that +Z is in, +Y up,
        # +X left for both world and camera space.
        ########################################################
        lights.location[..., 2] = -2.0
        phong_shader = HardPhongShader(
            lights=lights, cameras=cameras, materials=materials
        )
        phong_renderer = MeshRenderer(rasterizer=rasterizer, shader=phong_shader)
        images = phong_renderer(sphere_mesh, lights=lights)
        rgb = images[0, ..., :3].squeeze().cpu()
        if DEBUG:
            filename = "DEBUG_simple_sphere_dark%s.png" % postfix
            Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
                DATA_DIR / filename
            )

        # Load reference image
        image_ref_phong_dark = load_rgb_image("test_simple_sphere_dark%s.png" % postfix)
        self.assertTrue(torch.allclose(rgb, image_ref_phong_dark, atol=0.05))

    def test_simple_sphere_elevated_camera(self):
        """
        Render the ico-sphere with the default light configuration and check
        the phong, gouraud and flat shader outputs against their stored
        reference images, with the camera raised to a 45 degree elevation
        (delegates to test_simple_sphere with elevated_camera=True).
        """
        self.test_simple_sphere(elevated_camera=True)

    def test_simple_sphere_batched(self):
        """
        Test a mesh with vertex textures can be extended to form a batch, and
        is rendered correctly with Phong, Gouraud and Flat Shaders.
        """
        batch_size = 20
        device = torch.device("cuda:0")

        # Init mesh with vertex textures.
        sphere_meshes = ico_sphere(5, device).extend(batch_size)
        verts_padded = sphere_meshes.verts_padded()
        faces_padded = sphere_meshes.faces_padded()
        textures = Textures(verts_rgb=torch.ones_like(verts_padded))
        sphere_meshes = Meshes(
            verts=verts_padded, faces=faces_padded, textures=textures
        )

        # Init rasterizer settings: identical cameras for every batch element.
        dist = torch.tensor([2.7]).repeat(batch_size).to(device)
        elev = torch.zeros_like(dist)
        azim = torch.zeros_like(dist)
        R, T = look_at_view_transform(dist, elev, azim)
        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
        raster_settings = RasterizationSettings(
            image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0
        )

        # Init shader settings
        materials = Materials(device=device)
        lights = PointLights(device=device)
        lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]

        # Init renderer
        rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
        shaders = {
            # BUG FIX: "phong" previously mapped to HardGouraudShader, so the
            # phong reference image was being compared against gouraud output
            # (and HardPhongShader was never exercised in the batched test).
            "phong": HardPhongShader,
            "gouraud": HardGouraudShader,
            "flat": HardFlatShader,
        }
        for (name, shader_init) in shaders.items():
            shader = shader_init(lights=lights, cameras=cameras, materials=materials)
            renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
            images = renderer(sphere_meshes)
            image_ref = load_rgb_image("test_simple_sphere_light_%s.png" % name)
            # Every element of the batch must match the single-mesh reference.
            for i in range(batch_size):
                rgb = images[i, ..., :3].squeeze().cpu()
                self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))
facebook-github-bot's avatar
facebook-github-bot committed
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206

    def test_silhouette_with_grad(self):
        """
        Check soft silhouette rendering against the stored reference image,
        and verify that gradients flow back to the mesh vertices.
        """
        device = torch.device("cuda:0")
        ref_image_path = DATA_DIR / "test_silhouette.png"

        # Build a plain (untextured) sphere mesh.
        base_sphere = ico_sphere(5, device)
        verts, faces = base_sphere.get_mesh_verts_faces(0)
        mesh = Meshes(verts=[verts], faces=[faces])

        # Wide blur radius and many faces per pixel so the silhouette
        # boundary is soft (and therefore differentiable).
        blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
        raster_settings = RasterizationSettings(
            image_size=512,
            blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma,
            faces_per_pixel=80,
            bin_size=0,
        )

        # Camera looking straight at the sphere.
        R, T = look_at_view_transform(2.7, 0, 0)
        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)

        # Init renderer
        renderer = MeshRenderer(
            rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
            shader=SoftSilhouetteShader(blend_params=blend_params),
        )
        rendered = renderer(mesh)
        alpha = rendered[0, ..., 3].squeeze().cpu()
        if DEBUG:
            Image.fromarray((alpha.numpy() * 255).astype(np.uint8)).save(
                DATA_DIR / "DEBUG_silhouette.png"
            )

        # Compare the rendered alpha channel with the reference silhouette.
        with Image.open(ref_image_path) as raw_image_ref:
            image_ref = torch.from_numpy(np.array(raw_image_ref))
        image_ref = image_ref.to(dtype=torch.float32) / 255.0
        self.assertTrue(torch.allclose(alpha, image_ref, atol=0.055))

        # Re-render with verts requiring grad and check gradients exist.
        verts.requires_grad = True
        mesh = Meshes(verts=[verts], faces=[faces])
        rendered = renderer(mesh)
        rendered[0, ...].sum().backward()
        self.assertIsNotNone(verts.grad)

    def test_texture_map(self):
        """
        Test a mesh with a texture map is loaded and rendered correctly.
        The pupils in the eyes of the cow should always be looking to the left.
        """
        device = torch.device("cuda:0")
        # NOTE: intentionally shadows the module-level DATA_DIR — the cow
        # assets live under the tutorials data directory.
        DATA_DIR = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
        obj_filename = DATA_DIR / "cow_mesh/cow.obj"

        # Load mesh + texture
        mesh = load_objs_as_meshes([obj_filename], device=device)

        # Init rasterizer settings
        R, T = look_at_view_transform(2.7, 0, 0)
        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
        raster_settings = RasterizationSettings(
            image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0
        )

        # Init shader settings
        materials = Materials(device=device)
        lights = PointLights(device=device)

        # Place light behind the cow in world space. The front of
        # the cow is facing the -z direction.
        lights.location = torch.tensor([0.0, 0.0, 2.0], device=device)[None]

        # Init renderer
        renderer = MeshRenderer(
            rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
            shader=TexturedSoftPhongShader(
                lights=lights, cameras=cameras, materials=materials
            ),
        )
        images = renderer(mesh)
        rgb = images[0, ..., :3].squeeze().cpu()

        # Load reference image
        image_ref = load_rgb_image("test_texture_map_back.png")

        if DEBUG:
            Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
                DATA_DIR / "DEBUG_texture_map_back.png"
            )

        self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))

        # Check grad exists
        [verts] = mesh.verts_list()
        verts.requires_grad = True
        mesh2 = Meshes(verts=[verts], faces=mesh.faces_list(), textures=mesh.textures)
        images = renderer(mesh2)
        images[0, ...].sum().backward()
        self.assertIsNotNone(verts.grad)

        ##########################################
        # Check rendering of the front of the cow
        ##########################################

        R, T = look_at_view_transform(2.7, 0, 180)
        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)

        # Move light to the front of the cow in world space
        lights.location = torch.tensor([0.0, 0.0, -2.0], device=device)[None]
        images = renderer(mesh, cameras=cameras, lights=lights)
        rgb = images[0, ..., :3].squeeze().cpu()

        # Load reference image
        image_ref = load_rgb_image("test_texture_map_front.png")

        if DEBUG:
            Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
                DATA_DIR / "DEBUG_texture_map_front.png"
            )

        # BUG FIX: the front-view reference image was loaded but never
        # compared against the render — add the missing assertion.
        self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))

        #################################
        # Add blurring to rasterization
        #################################
        R, T = look_at_view_transform(2.7, 0, 180)
        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
        blend_params = BlendParams(sigma=5e-4, gamma=1e-4)
        raster_settings = RasterizationSettings(
            image_size=512,
            blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma,
            faces_per_pixel=100,
            bin_size=0,
        )

        # Override the rasterization/blend settings per-call.
        images = renderer(
            mesh.clone(),
            cameras=cameras,
            raster_settings=raster_settings,
            blend_params=blend_params,
        )
        rgb = images[0, ..., :3].squeeze().cpu()

        # Load reference image
        image_ref = load_rgb_image("test_blurry_textured_rendering.png")

        if DEBUG:
            Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
                DATA_DIR / "DEBUG_blurry_textured_rendering.png"
            )

        self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))