blending.py 8.51 KB
Newer Older
facebook-github-bot's avatar
facebook-github-bot committed
1
2
3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.


4
from typing import NamedTuple, Sequence
5

facebook-github-bot's avatar
facebook-github-bot committed
6
import torch
7
8

# pyre-fixme[21]: Could not find name `_C` in `pytorch3d`.
9
from pytorch3d import _C
facebook-github-bot's avatar
facebook-github-bot committed
10

11

facebook-github-bot's avatar
facebook-github-bot committed
12
13
14
15
16
17
18
19
20
# Example functions for blending the top K colors per pixel using the outputs
# from rasterization.
# NOTE: All blending functions should return an RGBA image per batch element


# Data class to store blending params with defaults
class BlendParams(NamedTuple):
    """
    Parameters consumed by the blending functions in this module.

    Fields:
        sigma: width of the sigmoid that turns the 2D pixel-to-face distance
            into a coverage probability (smaller values give sharper edges).
        gamma: scaling of the depth based exponential weights used when
            softly blending the colors of overlapping faces.
        background_color: RGB color used for background pixels; a 3 element
            list/tuple or a torch.Tensor.
    """

    sigma: float = 1e-4
    gamma: float = 1e-4
    background_color: Sequence = (1.0, 1.0, 1.0)
facebook-github-bot's avatar
facebook-github-bot committed
22
23


24
def hard_rgb_blend(colors, fragments, blend_params) -> torch.Tensor:
    """
    Naive blending of top K faces to return an RGBA image
      - **RGB** - choose color of the closest point i.e. K=0
      - **A** - 1.0

    Args:
        colors: (N, H, W, K, 3) RGB color for each of the top K faces per pixel.
        fragments: the outputs of rasterization. From this we use
            - pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
              of the faces (in the packed representation) which
              overlap each pixel in the image. This is used to
              determine the output shape and which pixels are background
              (those with ``pix_to_face[..., 0] < 0``).
        blend_params: BlendParams instance that contains a background_color
            field specifying the color for the background. It may be a
            3 element list/tuple or a torch.Tensor; a tensor is converted to
            the device and dtype of ``colors``.

    Returns:
        RGBA pixel_colors: (N, H, W, 4)
    """
    N, H, W, _ = fragments.pix_to_face.shape

    # Mask for the background: pixels whose closest face index is negative.
    is_background = fragments.pix_to_face[..., 0] < 0  # (N, H, W)

    background_color = blend_params.background_color
    if torch.is_tensor(background_color):
        # Match the device/dtype of ``colors``; otherwise e.g. a CPU
        # background combined with CUDA colors raises a device mismatch.
        background_color = background_color.to(colors)
    else:
        background_color = colors.new_tensor(background_color)  # (3)

    # Background pixels take background_color, all other pixels take the color
    # of their closest face. torch.where broadcasts the (3,) color against the
    # (N, H, W, 1) mask, so there is no need to count background pixels first.
    pixel_colors = torch.where(
        is_background[..., None], background_color, colors[..., 0, :]
    )  # (N, H, W, 3)

    # Hard blending always produces a fully opaque alpha channel.
    alpha = colors.new_ones((N, H, W, 1))
    return torch.cat([pixel_colors, alpha], dim=-1)  # (N, H, W, 4)
facebook-github-bot's avatar
facebook-github-bot committed
65
66


67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# Wrapper for the C++/CUDA Implementation of sigmoid alpha blend.
class _SigmoidAlphaBlend(torch.autograd.Function):
    """
    Autograd bridge to the compiled ``_C.sigmoid_alpha_blend`` kernels.

    forward(dists, pix_to_face, sigma) returns the kernel's alpha map.
    backward only produces a gradient for ``dists``; ``pix_to_face`` and
    ``sigma`` receive ``None``.
    """

    @staticmethod
    def forward(ctx, dists, pix_to_face, sigma):
        # sigma is a plain Python scalar, so it is stashed on ctx rather than
        # via save_for_backward (which only accepts tensors).
        ctx.sigma = sigma
        out = _C.sigmoid_alpha_blend(dists, pix_to_face, sigma)
        # The backward kernel consumes both the inputs and the forward output.
        ctx.save_for_backward(dists, pix_to_face, out)
        return out

    @staticmethod
    def backward(ctx, grad_alphas):
        saved_dists, saved_p2f, saved_alphas = ctx.saved_tensors
        grad_dists = _C.sigmoid_alpha_blend_backward(
            grad_alphas, saved_alphas, saved_dists, saved_p2f, ctx.sigma
        )
        # One entry per forward input: (dists, pix_to_face, sigma).
        return (grad_dists, None, None)


# pyre-fixme[16]: `_SigmoidAlphaBlend` has no attribute `apply`.
_sigmoid_alpha = _SigmoidAlphaBlend.apply


facebook-github-bot's avatar
facebook-github-bot committed
90
91
92
93
def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor:
    """
    Silhouette blending to return an RGBA image
      - **RGB** - the color of the closest face (index 0 along K).
      - **A** - blend based on the 2D distance based probability map [1],
        computed by the compiled ``_sigmoid_alpha`` autograd function.

    Args:
        colors: (N, H, W, K, 3) RGB color for each of the top K faces per pixel.
        fragments: the outputs of rasterization. From this we use
            - pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
              of the faces (in the packed representation) which
              overlap each pixel in the image.
            - dists: FloatTensor of shape (N, H, W, K) specifying
              the 2D euclidean distance from the center of each pixel
              to each of the top K overlapping faces.
        blend_params: BlendParams instance; only ``sigma`` is read and it is
            forwarded to the alpha kernel.

    Returns:
        RGBA pixel_colors: (N, H, W, 4), with the dtype and device of ``colors``.

    [1] Liu et al, 'Soft Rasterizer: A Differentiable Renderer for Image-based
        3D Reasoning', ICCV 2019
    """
    N, H, W, _ = fragments.pix_to_face.shape
    alpha = _sigmoid_alpha(fragments.dists, fragments.pix_to_face, blend_params.sigma)

    # Allocate the output in colors' dtype/device; the alpha map is cast into
    # it on assignment.
    rgba = colors.new_ones((N, H, W, 4))
    rgba[..., :3] = colors[..., 0, :]
    rgba[..., 3] = alpha
    return rgba
facebook-github-bot's avatar
facebook-github-bot committed
118
119


120
121
122
def softmax_rgb_blend(
    colors, fragments, blend_params, znear: float = 1.0, zfar: float = 100
) -> torch.Tensor:
    """
    RGB and alpha channel blending to return an RGBA image based on the method
    proposed in [1]
      - **RGB** - blend the colors based on the 2D distance based probability map and
        relative z distances.
      - **A** - blend based on the 2D distance based probability map.

    Args:
        colors: (N, H, W, K, 3) RGB color for each of the top K faces per pixel.
        fragments: namedtuple with outputs of rasterization. We use properties
            - pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
              of the faces (in the packed representation) which
              overlap each pixel in the image. Negative entries mark padded
              (empty) slots and are masked out.
            - dists: FloatTensor of shape (N, H, W, K) specifying
              the 2D euclidean distance from the center of each pixel
              to each of the top K overlapping faces.
            - zbuf: FloatTensor of shape (N, H, W, K) specifying
              the interpolated depth from each pixel to each of the
              top K overlapping faces.
        blend_params: instance of BlendParams dataclass containing properties
            - sigma: float, parameter which controls the width of the sigmoid
              function used to calculate the 2D distance based probability.
              Sigma controls the sharpness of the edges of the shape.
            - gamma: float, parameter which controls the scaling of the
              exponential function used to control the opacity of the color.
            - background_color: (3) element list/tuple/torch.Tensor specifying
              the RGB values for the background color. It is converted to a
              tensor on the device and with the dtype of ``colors``.
        znear: float, near clipping plane in the z direction
        zfar: float, far clipping plane in the z direction

    Returns:
        RGBA pixel_colors: (N, H, W, 4)

    [1] Shichen Liu et al, 'Soft Rasterizer: A Differentiable Renderer for
    Image-based 3D Reasoning'
    """
    N, H, W, _ = fragments.pix_to_face.shape
    pixel_colors = torch.ones((N, H, W, 4), dtype=colors.dtype, device=colors.device)

    # Bring the background color onto the device/dtype of ``colors``. A
    # user-supplied tensor would otherwise be used as-is and fail with a
    # device mismatch when it lives on a different device than the fragments.
    background = blend_params.background_color
    if torch.is_tensor(background):
        background = background.to(colors)
    else:
        background = colors.new_tensor(background)

    # Small constant that keeps the normalization terms strictly positive.
    eps = 1e-10

    # Mask for padded pixels.
    mask = fragments.pix_to_face >= 0

    # Sigmoid probability map based on the distance of the pixel to the face.
    prob_map = torch.sigmoid(-fragments.dists / blend_params.sigma) * mask

    # The cumulative product ensures that alpha will be 0.0 if at least 1
    # face fully covers the pixel as for that face, prob will be 1.0.
    # This results in a multiplication by 0.0 because of the (1.0 - prob)
    # term. Therefore 1.0 - alpha will be 1.0.
    alpha = torch.prod((1.0 - prob_map), dim=-1)

    # Weights for each face. Adjust the exponential by the max z to prevent
    # overflow. zbuf shape (N, H, W, K), find max over K.
    # TODO: there may still be some instability in the exponent calculation.

    z_inv = (zfar - fragments.zbuf) / (zfar - znear) * mask
    # pyre-fixme[16]: `Tuple` has no attribute `values`.
    # pyre-fixme[6]: Expected `Tensor` for 1st param but got `float`.
    z_inv_max = torch.max(z_inv, dim=-1).values[..., None].clamp(min=eps)
    # pyre-fixme[6]: Expected `Tensor` for 1st param but got `float`.
    weights_num = prob_map * torch.exp((z_inv - z_inv_max) / blend_params.gamma)

    # Also apply exp normalize trick for the background color weight.
    # Clamp to ensure delta is never 0.
    # pyre-fixme[20]: Argument `max` expected.
    # pyre-fixme[6]: Expected `Tensor` for 1st param but got `float`.
    delta = torch.exp((eps - z_inv_max) / blend_params.gamma).clamp(min=eps)

    # Normalize weights.
    # weights_num shape: (N, H, W, K). Sum over K and divide through by the sum.
    denom = weights_num.sum(dim=-1)[..., None] + delta

    # Sum: weights * textures + background color
    weighted_colors = (weights_num[..., None] * colors).sum(dim=-2)
    weighted_background = delta * background
    pixel_colors[..., :3] = (weighted_colors + weighted_background) / denom
    pixel_colors[..., 3] = 1.0 - alpha

    return pixel_colors