lighting.py 12.2 KB
Newer Older
Patrick Labatut's avatar
Patrick Labatut committed
1
2
3
4
5
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
facebook-github-bot's avatar
facebook-github-bot committed
6
7
8
9
10


import torch
import torch.nn.functional as F

11
from ..common.types import Device
facebook-github-bot's avatar
facebook-github-bot committed
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from .utils import TensorProperties, convert_to_tensors_and_broadcast


def diffuse(normals, color, direction) -> torch.Tensor:
    """
    Calculate the diffuse component of light reflection using Lambert's
    cosine law.

    Args:
        normals: (N, ..., 3) xyz normal vectors. They do not need to be unit
            length; they are renormalized internally.
        color: (1, 3) or (N, 3) RGB color of the diffuse component of the light.
        direction: (x,y,z) direction of the light, given either once per batch
            element (e.g. (1, 3) or (N, 3)) or per point with the same shape
            as normals.

    Returns:
        colors: (N, ..., 3), same shape as the input normals.

    The normals and light direction should be in the same coordinate frame
    i.e. if the points have been transformed from world -> view space then
    the normals and direction should also be in view space.

    NOTE: to use with the packed vertices (i.e. no batch dimension) reformat the
    inputs in the following way.

    .. code-block:: python

        Args:
            normals: (P, 3)
            color: (N, 3)[batch_idx, :] -> (P, 3)
            direction: (N, 3)[batch_idx, :] -> (P, 3)

        Returns:
            colors: (P, 3)

        where batch_idx is of shape (P). For meshes, batch_idx can be:
        meshes.verts_packed_to_mesh_idx() or meshes.faces_packed_to_mesh_idx()
        depending on whether the normals belong to the vertices or to
        average/interpolated face positions.
    """
    # TODO: handle multiple directional lights per batch element.
    # TODO: handle attenuation.

    # Ensure color and direction have the same batch dimension as normals.
    normals, color, direction = convert_to_tensors_and_broadcast(
        normals, color, direction, device=normals.device
    )

    # Reshape direction and color so they have all the arbitrary intermediate
    # dimensions as normals. Assume first dim = batch dim and last dim = 3.
    # Tensors that already match the normals' shape (per-point values) are
    # left untouched.
    points_dims = normals.shape[1:-1]
    expand_dims = (-1,) + (1,) * len(points_dims) + (3,)
    if direction.shape != normals.shape:
        direction = direction.view(expand_dims)
    if color.shape != normals.shape:
        color = color.view(expand_dims)

    # Renormalize the normals in case they have been interpolated.
    # We tried to replace the following with F.cosine_similarity, but it wasn't faster.
    normals = F.normalize(normals, p=2, dim=-1, eps=1e-6)
    direction = F.normalize(direction, p=2, dim=-1, eps=1e-6)
    # Lambert's law: intensity ~ max(0, n . l); back-facing points get 0.
    angle = F.relu(torch.sum(normals * direction, dim=-1))
    return color * angle[..., None]


def specular(
    points, normals, direction, color, camera_position, shininess
) -> torch.Tensor:
    """
    Calculate the specular component of light reflection (Phong model:
    color * max(0, r . v) ** shininess, where r is the reflected light
    direction and v the direction from the point to the camera).

    Args:
        points: (N, ..., 3) xyz coordinates of the points.
        normals: (N, ..., 3) xyz normal vectors for each point.
        direction: (N, 3) vector direction of the light.
        color: (N, 3) RGB color of the specular component of the light.
        camera_position: (N, 3) The xyz position of the camera.
        shininess: (N)  The specular exponent of the material.

    Returns:
        colors: (N, ..., 3), same shape as the input points.

    Raises:
        ValueError: if points and normals do not have the same shape.

    The points, normals, camera_position, and direction should be in the same
    coordinate frame i.e. if the points have been transformed from
    world -> view space then the normals, camera_position, and light direction
    should also be in view space.

    To use with a batch of packed points reindex in the following way.

    .. code-block:: python

        Args:
            points: (P, 3)
            normals: (P, 3)
            color: (N, 3)[batch_idx] -> (P, 3)
            direction: (N, 3)[batch_idx] -> (P, 3)
            camera_position: (N, 3)[batch_idx] -> (P, 3)
            shininess: (N)[batch_idx] -> (P)
        Returns:
            colors: (P, 3)

        where batch_idx is of shape (P). For meshes batch_idx can be:
        meshes.verts_packed_to_mesh_idx() or meshes.faces_packed_to_mesh_idx().
    """
    # TODO: handle multiple directional lights
    # TODO: attenuate based on inverse squared distance to the light source

    if points.shape != normals.shape:
        msg = "Expected points and normals to have the same shape: got %r, %r"
        raise ValueError(msg % (points.shape, normals.shape))

    # Ensure all inputs have same batch dimension as points
    matched_tensors = convert_to_tensors_and_broadcast(
        points, color, direction, camera_position, shininess, device=points.device
    )
    _, color, direction, camera_position, shininess = matched_tensors

    # Reshape direction and color so they have all the arbitrary intermediate
    # dimensions as points. Assume first dim = batch dim and last dim = 3.
    points_dims = points.shape[1:-1]
    expand_dims = (-1,) + (1,) * len(points_dims)
    if direction.shape != normals.shape:
        direction = direction.view(expand_dims + (3,))
    if color.shape != normals.shape:
        color = color.view(expand_dims + (3,))
    if camera_position.shape != normals.shape:
        camera_position = camera_position.view(expand_dims + (3,))
    if shininess.shape != normals.shape:
        # shininess is a scalar per element, so it has no trailing xyz axis.
        shininess = shininess.view(expand_dims)

    # Renormalize the normals in case they have been interpolated.
    # We tried a version that uses F.cosine_similarity instead of renormalizing,
    # but it was slower.
    normals = F.normalize(normals, p=2, dim=-1, eps=1e-6)
    direction = F.normalize(direction, p=2, dim=-1, eps=1e-6)
    cos_angle = torch.sum(normals * direction, dim=-1)
    # No specular highlights if angle is less than 0. The mask is built in the
    # inputs' dtype so that half/double precision inputs are not promoted to
    # float32 by the multiplication below.
    mask = (cos_angle > 0).to(cos_angle.dtype)

    # Calculate the specular reflection.
    view_direction = camera_position - points
    view_direction = F.normalize(view_direction, p=2, dim=-1, eps=1e-6)
    # Mirror the light direction about the normal: r = 2 (n . l) n - l.
    reflect_direction = -direction + 2 * (cos_angle[..., None] * normals)

    # Cosine of the angle between the reflected light ray and the viewer
    alpha = F.relu(torch.sum(view_direction * reflect_direction, dim=-1)) * mask
    return color * torch.pow(alpha, shininess)[..., None]


class DirectionalLights(TensorProperties):
    """
    A batch of directional lights with ambient, diffuse and specular color
    components. The light direction is the same for every shaded point, so
    the points themselves are not needed for diffuse shading.
    """

    def __init__(
        self,
        ambient_color=((0.5, 0.5, 0.5),),
        diffuse_color=((0.3, 0.3, 0.3),),
        specular_color=((0.2, 0.2, 0.2),),
        direction=((0, 1, 0),),
        device: Device = "cpu",
    ):
        """
        Args:
            ambient_color: RGB color of the ambient component.
            diffuse_color: RGB color of the diffuse component.
            specular_color: RGB color of the specular component.
            direction: (x, y, z) direction vector of the light.
            device: Device (as str or torch.device) on which the tensors should be located

        The inputs can each be
            - 3 element tuple/list or list of lists
            - torch tensor of shape (1, 3)
            - torch tensor of shape (N, 3)
        The inputs are broadcast against each other so they all have batch
        dimension N.

        Raises:
            ValueError: if any color or the direction does not end in a
                dimension of size 3.
        """
        super().__init__(
            device=device,
            ambient_color=ambient_color,
            diffuse_color=diffuse_color,
            specular_color=specular_color,
            direction=direction,
        )
        _validate_light_properties(self)
        if self.direction.shape[-1] != 3:
            msg = "Expected direction to have shape (N, 3); got %r"
            # Wrap the shape in a 1-tuple: torch.Size is itself a tuple and
            # would otherwise be unpacked as several format arguments, while
            # passing repr(...) would get quoted a second time by %r.
            raise ValueError(msg % (self.direction.shape,))

    def clone(self):
        """
        Return a clone of these lights: a fresh instance on the same device
        that is handed to TensorProperties.clone (presumably to copy the
        tensor attributes into it).
        """
        other = self.__class__(device=self.device)
        return super().clone(other)

    def diffuse(self, normals, points=None) -> torch.Tensor:
        """Diffuse shading of `normals` under this light's direction."""
        # NOTE: Points is not used but is kept in the args so that the API is
        # the same for directional and point lights. The call sites should not
        # need to know the light type.
        return diffuse(
            normals=normals,
            color=self.diffuse_color,
            direction=self.direction,
        )

    def specular(self, normals, points, camera_position, shininess) -> torch.Tensor:
        """Specular shading of `points` as seen from `camera_position`."""
        return specular(
            points=points,
            normals=normals,
            color=self.specular_color,
            direction=self.direction,
            camera_position=camera_position,
            shininess=shininess,
        )


class PointLights(TensorProperties):
    """
    A batch of point lights placed at `location`, with ambient, diffuse and
    specular color components. The light direction for each shaded point is
    the vector from that point to the light, i.e. ``location - points``.
    """

    def __init__(
        self,
        ambient_color=((0.5, 0.5, 0.5),),
        diffuse_color=((0.3, 0.3, 0.3),),
        specular_color=((0.2, 0.2, 0.2),),
        location=((0, 1, 0),),
        device: Device = "cpu",
    ):
        """
        Args:
            ambient_color: RGB color of the ambient component
            diffuse_color: RGB color of the diffuse component
            specular_color: RGB color of the specular component
            location: xyz position of the light.
            device: Device (as str or torch.device) on which the tensors should be located

        The inputs can each be
            - 3 element tuple/list or list of lists
            - torch tensor of shape (1, 3)
            - torch tensor of shape (N, 3)
        The inputs are broadcast against each other so they all have batch
        dimension N.

        Raises:
            ValueError: if any color or the location does not end in a
                dimension of size 3.
        """
        super().__init__(
            device=device,
            ambient_color=ambient_color,
            diffuse_color=diffuse_color,
            specular_color=specular_color,
            location=location,
        )
        _validate_light_properties(self)
        if self.location.shape[-1] != 3:
            msg = "Expected location to have shape (N, 3); got %r"
            # Wrap the shape in a 1-tuple: torch.Size is itself a tuple and
            # would otherwise be unpacked as several format arguments, while
            # passing repr(...) would get quoted a second time by %r.
            raise ValueError(msg % (self.location.shape,))

    def clone(self):
        """
        Return a clone of these lights: a fresh instance on the same device
        that is handed to TensorProperties.clone (presumably to copy the
        tensor attributes into it).
        """
        other = self.__class__(device=self.device)
        return super().clone(other)

    def reshape_location(self, points) -> torch.Tensor:
        """
        Reshape the location tensor so that it broadcasts against `points`.

        If `points` has the same rank as the location (e.g. packed points of
        shape (P, 3)) the location is returned unchanged; presumably the
        caller has already made the two broadcast-compatible.

        Otherwise `points` is treated as batched with any number of
        intermediate dimensions, e.g. (N, H, W, K, 3) or (N, P, 3), and one
        singleton axis is inserted after the batch axis for every missing
        dimension, e.g. (N, 3) -> (N, 1, 1, 1, 3) for (N, H, W, K, 3) points.
        """
        if self.location.ndim == points.ndim:
            # pyre-fixme[7]
            return self.location
        num_missing = points.ndim - self.location.ndim
        if num_missing > 0:
            # Equivalent to location[:, None, ..., None, ...] with
            # `num_missing` Nones; for 5D points this is exactly the
            # historical location[:, None, None, None, :].
            index = (slice(None),) + (None,) * num_missing + (Ellipsis,)
            # pyre-fixme[29]
            return self.location[index]
        # points has fewer dims than the location: keep the historical
        # reshape unchanged for backward compatibility.
        # pyre-fixme[29]
        return self.location[:, None, None, None, :]

    def diffuse(self, normals, points) -> torch.Tensor:
        """Diffuse shading; the light direction is the point-to-light vector."""
        location = self.reshape_location(points)
        direction = location - points
        return diffuse(normals=normals, color=self.diffuse_color, direction=direction)

    def specular(self, normals, points, camera_position, shininess) -> torch.Tensor:
        """Specular shading of `points` as seen from `camera_position`."""
        location = self.reshape_location(points)
        direction = location - points
        return specular(
            points=points,
            normals=normals,
            color=self.specular_color,
            direction=direction,
            camera_position=camera_position,
            shininess=shininess,
        )


Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
290
291
292
293
294
295
296
class AmbientLights(TensorProperties):
    """
    Light of one uniform color that reaches every point equally.

    It contributes no diffuse or specular term. With the default white color
    this effectively renders without lighting.
    """

    def __init__(self, *, ambient_color=None, device: Device = "cpu"):
        """
        Args:
            ambient_color: RGB color. When None, white ((1.0, 1.0, 1.0),) is
                used. Otherwise it should be a sequence of float triples, i.e.
                    - 3 element tuple/list or list of lists
                    - torch tensor of shape (1, 3)
                    - torch tensor of shape (N, 3)
            device: Device (as str or torch.device) on which the tensors should be located
        """
        chosen_color = ((1.0, 1.0, 1.0),) if ambient_color is None else ambient_color
        super().__init__(ambient_color=chosen_color, device=device)

    def clone(self):
        """Clone via TensorProperties.clone, seeded with a fresh instance."""
        return super().clone(self.__class__(device=self.device))

    def diffuse(self, normals, points) -> torch.Tensor:
        """No diffuse contribution: zeros shaped like `points`."""
        return torch.zeros_like(points)

    def specular(self, normals, points, camera_position, shininess) -> torch.Tensor:
        """No specular contribution: zeros shaped like `points`."""
        return torch.zeros_like(points)


facebook-github-bot's avatar
facebook-github-bot committed
326
327
328
329
330
331
332
def _validate_light_properties(obj):
    props = ("ambient_color", "diffuse_color", "specular_color")
    for n in props:
        t = getattr(obj, n)
        if t.shape[-1] != 3:
            msg = "Expected %s to have shape (N, 3); got %r"
            raise ValueError(msg % (n, t.shape))