# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
10
from pytorch3d.renderer.cameras import look_at_view_transform, PerspectiveCameras
11
12
from pytorch3d.renderer.mesh.rasterizer import Fragments
from pytorch3d.renderer.mesh.shader import (
13
    HardDepthShader,
14
15
16
    HardFlatShader,
    HardGouraudShader,
    HardPhongShader,
17
    SoftDepthShader,
18
    SoftPhongShader,
19
    SplatterPhongShader,
20
21
22
)
from pytorch3d.structures.meshes import Meshes

Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
23
24
from .common_testing import TestCaseMixin

25
26

class TestShader(TestCaseMixin, unittest.TestCase):
    """Tests for device handling and camera validation of the mesh shaders."""

    def setUp(self):
        """Collect the shader classes exercised by the generic tests."""
        self.shader_classes = [
            HardDepthShader,
            HardFlatShader,
            HardGouraudShader,
            HardPhongShader,
            SoftDepthShader,
            SoftPhongShader,
            SplatterPhongShader,
        ]

    @staticmethod
    def _make_meshes_and_fragment_inputs():
        """Build the small fixture shared by the shading tests.

        Returns:
            A tuple ``(meshes, pix_to_face, barycentric_coords)`` where
            ``meshes`` holds a single mesh with 4 vertices and 2 faces,
            ``pix_to_face`` is an int64 tensor of shape (1, 1, 1, 2) and
            ``barycentric_coords`` is a float32 tensor of shape
            (1, 1, 1, 2, 3).
        """
        verts = torch.tensor(
            [[-1, -1, 0], [1, -1, 1], [1, 1, 0], [-1, 1, 1]], dtype=torch.float32
        )
        faces = torch.tensor([[0, 1, 2], [2, 3, 0]], dtype=torch.int64)
        meshes = Meshes(verts=[verts], faces=[faces])

        pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
        barycentric_coords = torch.tensor(
            [[0.1, 0.2, 0.7], [0.3, 0.5, 0.2]], dtype=torch.float32
        ).view(1, 1, 1, 2, -1)
        return meshes, pix_to_face, barycentric_coords

    def test_to(self):
        """Check that ``to()`` moves a shader's members to the target device.

        Every shader class is tried both without cameras and with
        ``PerspectiveCameras``. The test moves shaders to ``cuda:0``, so it
        needs a CUDA device to run.
        """
        cpu_device = torch.device("cpu")
        cuda_device = torch.device("cuda:0")

        R, T = look_at_view_transform()

        for shader_class in self.shader_classes:
            for cameras_class in (None, PerspectiveCameras):
                cameras_name = getattr(cameras_class, "__name__", None)
                # subTest makes a failure name the offending combination.
                with self.subTest(
                    shader=shader_class.__name__, cameras=cameras_name
                ):
                    if cameras_class is None:
                        cameras = None
                    else:
                        cameras = PerspectiveCameras(device=cpu_device, R=R, T=T)

                    cpu_shader = shader_class(device=cpu_device, cameras=cameras)
                    if cameras is None:
                        self.assertIsNone(cpu_shader.cameras)
                    else:
                        self.assertEqual(cpu_device, cpu_shader.cameras.device)
                    self.assertEqual(cpu_device, cpu_shader.materials.device)
                    self.assertEqual(cpu_device, cpu_shader.lights.device)

                    cuda_shader = cpu_shader.to(cuda_device)
                    # to() is expected to return the very same shader object.
                    self.assertIs(cpu_shader, cuda_shader)
                    if cameras is None:
                        self.assertIsNone(cuda_shader.cameras)
                        # assertRaisesRegexp was a deprecated alias that was
                        # removed in Python 3.12; use assertRaisesRegex.
                        with self.assertRaisesRegex(ValueError, "Cameras must be"):
                            cuda_shader._get_cameras()
                    else:
                        self.assertEqual(cuda_device, cuda_shader.cameras.device)
                        self.assertIsInstance(
                            cuda_shader._get_cameras(), cameras_class
                        )
                    self.assertEqual(cuda_device, cuda_shader.materials.device)
                    self.assertEqual(cuda_device, cuda_shader.lights.device)

    def test_cameras_check(self):
        """A shader with no cameras must raise ValueError when called without any."""
        meshes, pix_to_face, barycentric_coords = (
            self._make_meshes_and_fragment_inputs()
        )
        fragments = Fragments(
            pix_to_face=pix_to_face,
            bary_coords=barycentric_coords,
            zbuf=torch.ones_like(pix_to_face),
            dists=torch.ones_like(pix_to_face),
        )

        for shader_class in self.shader_classes:
            with self.subTest(shader=shader_class.__name__):
                shader = shader_class()

                with self.assertRaises(ValueError):
                    shader(fragments, meshes)

    def test_depth_shader(self):
        """Depth shaders must return a (1, 1, 1, 1) output for 1 or 2 faces per pixel."""
        shader_classes = [
            HardDepthShader,
            SoftDepthShader,
        ]

        meshes, pix_to_face, barycentric_coords = (
            self._make_meshes_and_fragment_inputs()
        )

        # The cameras do not depend on faces_per_pixel, so build them once.
        R, T = look_at_view_transform()
        cameras = PerspectiveCameras(R=R, T=T)

        for faces_per_pixel in [1, 2]:
            pix_to_face_k = pix_to_face[:, :, :, :faces_per_pixel]
            # zbuf and dists are sized from the sliced pix_to_face so that
            # every Fragments field agrees on the faces-per-pixel dimension.
            fragments = Fragments(
                pix_to_face=pix_to_face_k,
                bary_coords=barycentric_coords[:, :, :, :faces_per_pixel],
                zbuf=torch.ones_like(pix_to_face_k),
                dists=torch.ones_like(pix_to_face_k),
            )

            for shader_class in shader_classes:
                with self.subTest(
                    shader=shader_class.__name__, faces_per_pixel=faces_per_pixel
                ):
                    shader = shader_class()

                    out = shader(fragments, meshes, cameras=cameras)
                    self.assertEqual(out.shape, (1, 1, 1, 1))