blending.py 9.38 KB
Newer Older
Patrick Labatut's avatar
Patrick Labatut committed
1
2
3
4
5
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
facebook-github-bot's avatar
facebook-github-bot committed
6
7


Nikhila Ravi's avatar
Nikhila Ravi committed
8
from typing import NamedTuple, Sequence, Union
9

facebook-github-bot's avatar
facebook-github-bot committed
10
import torch
Patrick Labatut's avatar
Patrick Labatut committed
11
from pytorch3d import _C
facebook-github-bot's avatar
facebook-github-bot committed
12

13

facebook-github-bot's avatar
facebook-github-bot committed
14
15
16
17
18
19
# Example functions for blending the top K colors per pixel using the outputs
# from rasterization.
# NOTE: All blending functions should return an RGBA image per batch element.


class BlendParams(NamedTuple):
    """
    Immutable bundle of the hyperparameters read by the blending functions.

    Attributes:
        sigma (float): Width of the sigmoid that maps the 2D pixel-to-face
            distance to a coverage probability; it sets how sharp the edges
            of the shape look. Larger values => softer, less defined edges.
        gamma (float): Scale of the exponential used to set the opacity of
            the color. Larger values => more transparent faces.
        background_color: RGB color used for background pixels, given either
            as a sequence of three floats or as a tensor of three floats.
    """

    # Sharpness of the 2D distance-based probability map.
    sigma: float = 1e-4
    # Scale of the opacity exponential.
    gamma: float = 1e-4
    # White background by default.
    background_color: Union[torch.Tensor, Sequence[float]] = (1.0, 1.0, 1.0)
facebook-github-bot's avatar
facebook-github-bot committed
38
39


Patrick Labatut's avatar
Patrick Labatut committed
40
41
42
def hard_rgb_blend(
    colors: torch.Tensor, fragments, blend_params: BlendParams
) -> torch.Tensor:
    """
    Naive blending of top K faces to return an RGBA image
      - **RGB** - choose color of the closest point i.e. K=0; background
        pixels get blend_params.background_color
      - **A** - 1.0 for pixels covered by at least one face, 0.0 for
        background pixels

    Args:
        colors: (N, H, W, K, 3) RGB color for each of the top K faces per pixel.
        fragments: the outputs of rasterization. From this we use
            - pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
              of the faces (in the packed representation) which
              overlap each pixel in the image. A negative value in slot 0
              marks a background pixel.
        blend_params: BlendParams instance whose background_color field (a
            sequence or tensor of three floats) gives the background color.
            It is converted to the dtype and device of ``colors``.

    Returns:
        RGBA pixel_colors: (N, H, W, 4)
    """
    # Mask for the background: no face lands in the closest slot.
    is_background = fragments.pix_to_face[..., 0] < 0  # (N, H, W)

    # Match colors' dtype and device regardless of how the background was
    # supplied; otherwise e.g. a float64 tensor background cannot be combined
    # with float32 colors.
    background_color_ = blend_params.background_color
    if isinstance(background_color_, torch.Tensor):
        background_color = background_color_.to(
            device=colors.device, dtype=colors.dtype
        )
    else:
        background_color = colors.new_tensor(background_color_)

    # Pick the closest face's color, or the background color, per pixel.
    # torch.where broadcasts the (3,) background against (N, H, W, 3) and,
    # unlike masked_scatter + expand, needs no count of background pixels
    # (which would force a device-to-host sync).
    pixel_colors = torch.where(
        is_background[..., None], background_color, colors[..., 0, :]
    )  # (N, H, W, 3)

    # Concat with the alpha channel.
    alpha = (~is_background).type_as(pixel_colors)[..., None]

    return torch.cat([pixel_colors, alpha], dim=-1)  # (N, H, W, 4)
facebook-github-bot's avatar
facebook-github-bot committed
85
86


87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# Wrapper for the C++/CUDA Implementation of sigmoid alpha blend.
class _SigmoidAlphaBlend(torch.autograd.Function):
    @staticmethod
    def forward(ctx, dists, pix_to_face, sigma):
        alphas = _C.sigmoid_alpha_blend(dists, pix_to_face, sigma)
        ctx.save_for_backward(dists, pix_to_face, alphas)
        ctx.sigma = sigma
        return alphas

    @staticmethod
    def backward(ctx, grad_alphas):
        dists, pix_to_face, alphas = ctx.saved_tensors
        sigma = ctx.sigma
        grad_dists = _C.sigmoid_alpha_blend_backward(
            grad_alphas, alphas, dists, pix_to_face, sigma
        )
        return grad_dists, None, None


# pyre-fixme[16]: `_SigmoidAlphaBlend` has no attribute `apply`.
_sigmoid_alpha = _SigmoidAlphaBlend.apply


Patrick Labatut's avatar
Patrick Labatut committed
110
def sigmoid_alpha_blend(colors, fragments, blend_params: BlendParams) -> torch.Tensor:
    """
    Silhouette blending to return an RGBA image
      - **RGB** - color of the closest face (index 0 along K).
      - **A** - blend based on the 2D distance based probability map [1],
        computed by the compiled sigmoid alpha-blend kernel.

    Args:
        colors: (N, H, W, K, 3) RGB color for each of the top K faces per pixel.
        fragments: the outputs of rasterization. From this we use
            - pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
              of the faces (in the packed representation) which
              overlap each pixel in the image.
            - dists: FloatTensor of shape (N, H, W, K) specifying
              the 2D euclidean distance from the center of each pixel
              to each of the top K overlapping faces.
        blend_params: BlendParams instance; only its sigma field is read.

    Returns:
        RGBA pixel_colors: (N, H, W, 4), with the dtype and device of colors.

    [1] Liu et al, 'Soft Rasterizer: A Differentiable Renderer for Image-based
        3D Reasoning', ICCV 2019
    """
    batch, height, width, _ = fragments.pix_to_face.shape
    rgba = colors.new_ones((batch, height, width, 4))
    rgba[..., :3] = colors[..., 0, :]
    silhouette = _sigmoid_alpha(
        fragments.dists, fragments.pix_to_face, blend_params.sigma
    )
    rgba[..., 3] = silhouette
    return rgba
facebook-github-bot's avatar
facebook-github-bot committed
138
139


140
def softmax_rgb_blend(
    colors: torch.Tensor,
    fragments,
    blend_params: BlendParams,
    znear: Union[float, torch.Tensor] = 1.0,
    zfar: Union[float, torch.Tensor] = 100,
) -> torch.Tensor:
    """
    RGB and alpha channel blending to return an RGBA image based on the method
    proposed in [1]
      - **RGB** - blend the colors based on the 2D distance based probability map and
        relative z distances.
      - **A** - blend based on the 2D distance based probability map.

    Args:
        colors: (N, H, W, K, 3) RGB color for each of the top K faces per pixel.
        fragments: namedtuple with outputs of rasterization. We use properties
            - pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
              of the faces (in the packed representation) which
              overlap each pixel in the image. Negative entries mark empty slots.
            - dists: FloatTensor of shape (N, H, W, K) specifying
              the 2D euclidean distance from the center of each pixel
              to each of the top K overlapping faces.
            - zbuf: FloatTensor of shape (N, H, W, K) specifying
              the interpolated depth from each pixel to each of the
              top K overlapping faces.
        blend_params: instance of BlendParams dataclass containing properties
            - sigma: float, parameter which controls the width of the sigmoid
              function used to calculate the 2D distance based probability.
              Sigma controls the sharpness of the edges of the shape.
            - gamma: float, parameter which controls the scaling of the
              exponential function used to control the opacity of the color.
            - background_color: (3) element list/tuple/torch.Tensor specifying
              the RGB values for the background color. It is converted to the
              dtype and device of ``colors``.
        znear: near clipping plane in the z direction. A float, a 0-dim
            tensor, or a tensor of shape (N,) with one value per batch element.
        zfar: far clipping plane in the z direction. Same accepted forms as
            znear.

    Returns:
        RGBA pixel_colors: (N, H, W, 4)

    [1] Shichen Liu et al, 'Soft Rasterizer: A Differentiable Renderer for
    Image-based 3D Reasoning'
    """
    N, H, W = fragments.pix_to_face.shape[:3]
    pixel_colors = torch.ones((N, H, W, 4), dtype=colors.dtype, device=colors.device)

    # Bring the background color to colors' dtype/device so mixing it into the
    # output neither promotes dtypes nor depends on a hard-coded float32.
    background_ = blend_params.background_color
    if isinstance(background_, torch.Tensor):
        background = background_.to(device=colors.device, dtype=colors.dtype)
    else:
        background = colors.new_tensor(background_)

    # Lower bound for z_inv_max and for the background weight delta below.
    eps = 1e-10

    # Mask for padded pixels.
    mask = fragments.pix_to_face >= 0

    # Sigmoid probability map based on the distance of the pixel to the face.
    prob_map = torch.sigmoid(-fragments.dists / blend_params.sigma) * mask

    # The cumulative product ensures that alpha will be 0.0 if at least 1
    # face fully covers the pixel as for that face, prob will be 1.0.
    # This results in a multiplication by 0.0 because of the (1.0 - prob)
    # term. Therefore 1.0 - alpha will be 1.0.
    alpha = torch.prod((1.0 - prob_map), dim=-1)

    # Weights for each face. Adjust the exponential by the max z to prevent
    # overflow. zbuf shape (N, H, W, K), find max over K.
    # TODO: there may still be some instability in the exponent calculation.

    # Reshape tensor clip planes to broadcast against (N, H, W, K) values in
    # fragments. reshape(-1, 1, 1, 1) handles both (N,) tensors and 0-dim
    # tensors (the latter cannot be indexed with [:, None, None, None]).
    if torch.is_tensor(zfar):
        # pyre-fixme[16]
        zfar = zfar.reshape(-1, 1, 1, 1)
    if torch.is_tensor(znear):
        znear = znear.reshape(-1, 1, 1, 1)

    z_inv = (zfar - fragments.zbuf) / (zfar - znear) * mask
    z_inv_max = torch.max(z_inv, dim=-1).values[..., None].clamp(min=eps)
    weights_num = prob_map * torch.exp((z_inv - z_inv_max) / blend_params.gamma)

    # Also apply exp normalize trick for the background color weight.
    # Clamp to ensure delta is never 0.
    # pyre-fixme[6]: Expected `Tensor` for 1st param but got `float`.
    delta = torch.exp((eps - z_inv_max) / blend_params.gamma).clamp(min=eps)

    # Normalize weights.
    # weights_num shape: (N, H, W, K). Sum over K and divide through by the sum.
    denom = weights_num.sum(dim=-1)[..., None] + delta

    # Sum: weights * textures + background color
    weighted_colors = (weights_num[..., None] * colors).sum(dim=-2)
    weighted_background = delta * background
    pixel_colors[..., :3] = (weighted_colors + weighted_background) / denom
    pixel_colors[..., 3] = 1.0 - alpha

    return pixel_colors