# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Optional, Tuple

import torch as th
import torch.nn.functional as thf
from drtk import edge_grad_ext
from drtk.interpolate import interpolate
from drtk.utils import index

th.ops.load_library(edge_grad_ext.__file__)


@th.compiler.disable
def edge_grad_estimator(
    v_pix: th.Tensor,
    vi: th.Tensor,
    bary_img: th.Tensor,
    img: th.Tensor,
    index_img: th.Tensor,
    v_pix_img_hook: Optional[Callable[[th.Tensor], None]] = None,
) -> th.Tensor:
    """Makes the rasterized image ``img`` differentiable at visibility discontinuities
    and backpropagates the gradients to ``v_pix``.

    This function takes a rasterized image ``img`` that is assumed to be differentiable
    at continuous regions but not at discontinuities. In some cases, ``img`` may not be
    differentiable at all. For example, if the image is a rendered segmentation mask,
    it remains constant at continuous regions, making it non-differentiable. However,
    ``edge_grad_estimator`` can still compute gradients at the discontinuities with
    respect to ``v_pix``.

    The arguments ``bary_img`` and ``index_img`` must correspond exactly to the
    rasterized image ``img``. Each pixel in ``img`` should correspond to a fragment
    originated from the primitive specified by ``index_img`` and it should have
    barycentric coordinates specified by ``bary_img``. This means that with a small
    change to ``v_pix``, the pixels in ``img`` should change accordingly. A frequent
    mistake that violates this condition is applying a mask to the rendered image to
    exclude unwanted regions, which leads to erroneous gradients.

    The function returns the ``img`` unchanged but with added differentiability at the
    discontinuities. Note that it is not necessary for the input ``img`` to require
    gradients, but the returned ``img`` will always require gradients.

    Args:
        v_pix (Tensor): Pixel-space vertex coordinates, preserving the original
            camera-space Z-values. Shape: :math:`(N, V, 3)`.
        vi (Tensor): Face vertex index list tensor. Shape: :math:`(F, 3)`.
        bary_img (Tensor): 3D barycentric coordinate image tensor.
            Shape: :math:`(N, 3, H, W)`.
        img (Tensor): The rendered image. Shape: :math:`(N, C, H, W)`.
        index_img (Tensor): Index image tensor. Shape: :math:`(N, H, W)`.
        v_pix_img_hook (Optional[Callable[[th.Tensor], None]]): An optional backward
            hook that will be registered to ``v_pix_img``. Useful for examining the
            generated image space. Default is None.

    Returns:
        Tensor: Returns the input ``img`` unchanged. However, the returned image now
        has added differentiability at visibility discontinuities. This returned image
        should be used for computing losses.

    Note:
        It is crucial not to spatially modify the rasterized image before passing it to
        `edge_grad_estimator`. That stems from the requirement that ``bary_img`` and
        ``index_img`` must correspond exactly to the rasterized image ``img``. That
        means that the location of all discontinuities is controlled by ``v_pix`` and
        can be modified by modifying ``v_pix``.

        Operations that are allowed, as long as they are differentiable, include:

        - Pixel-wise MLP
        - Color mapping
        - Color correction, gamma correction
        - Anything that would be indistinguishable from processing fragments
          independently before their values get assigned to pixels of ``img``

        Operations that **must be avoided** before `edge_grad_estimator` include:

        - Gaussian blur
        - Warping or deformation
        - Masking, cropping, or introducing holes

        There is, however, no issue with applying them after `edge_grad_estimator`.

        If the operation is highly non-linear, it is recommended to perform it before
        calling :func:`edge_grad_estimator`. All sorts of clipping and clamping
        (e.g., `x.clamp(min=0.0, max=1.0)`) must also be done before invoking this
        function.

    Usage Example::

        import torch.nn.functional as thf
        from drtk import transform, rasterize, render, interpolate, edge_grad_estimator

        ...

        v_pix = transform(v, tex, campos, camrot, focal, princpt)
        index_img = rasterize(v_pix, vi, width=512, height=512)
        _, bary_img = render(v_pix, vi, index_img)
        vt_img = interpolate(vt, vti, index_img, bary_img)

        img = thf.grid_sample(
            tex,
            vt_img.permute(0, 2, 3, 1),
            mode="bilinear",
            padding_mode="border",
            align_corners=False
        )

        mask = (index_img != -1)[:, None, :, :]

        img = img * mask

        img = edge_grad_estimator(
            v_pix=v_pix,
            vi=vi,
            bary_img=bary_img,
            img=img,
            index_img=index_img
        )

        optim.zero_grad()
        image_loss = loss_func(img, img_gt)
        image_loss.backward()
        optim.step()
    """

    # TODO: avoid call to interpolate, use backward kernel of interpolate directly
    # Doing so will make `edge_grad_estimator` zero-overhead in forward pass
    # At the moment, value of `v_pix_img` is ignored, and only passed to
    # edge_grad_estimator so that backward kernel can be called with the computed
    # gradient.
    v_pix_img = interpolate(v_pix, vi, index_img, bary_img.detach())

    img = th.ops.edge_grad_ext.edge_grad_estimator(v_pix, v_pix_img, vi, img, index_img)

    if v_pix_img_hook is not None:
        v_pix_img.register_hook(v_pix_img_hook)
    return img


def edge_grad_estimator_ref(
    v_pix: th.Tensor,
    vi: th.Tensor,
    bary_img: th.Tensor,
    img: th.Tensor,
    index_img: th.Tensor,
    v_pix_img_hook: Optional[Callable[[th.Tensor], None]] = None,
) -> th.Tensor:
    """
    Python reference implementation for :func:`drtk.edge_grad_estimator`.

    Takes the same arguments as :func:`edge_grad_estimator`, but routes the image
    through :class:`EdgeGradEstimatorFunction` instead of the compiled extension op.

    Returns:
        Tensor: The input ``img``, passed through the autograd function.
    """

    # could use v_pix_img output from DRTK, but bary_img needs to be detached.
    v_pix_img = interpolate(v_pix, vi, index_img, bary_img.detach())

    # pyre-fixme[16]: `EdgeGradEstimatorFunction` has no attribute `apply`.
    img = EdgeGradEstimatorFunction.apply(v_pix, v_pix_img, vi, img, index_img)

    if v_pix_img_hook is not None:
        v_pix_img.register_hook(v_pix_img_hook)
    return img


class EdgeGradEstimatorFunction(th.autograd.Function):
    """Autograd function behind :func:`edge_grad_estimator_ref`.

    The forward pass returns ``img`` as is. The backward pass passes ``grad_output``
    through to ``img`` and additionally produces a gradient for ``v_pix_img``,
    estimated from differences between horizontally and vertically adjacent pixels
    whose entries in ``index_img`` differ.
    """

    @staticmethod
    # pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently.
    def forward(
        ctx,
        v_pix: th.Tensor,
        v_pix_img: th.Tensor,
        vi: th.Tensor,
        img: th.Tensor,
        index_img: th.Tensor,
    ) -> th.Tensor:
        """Stores the tensors needed by ``backward`` and returns ``img`` unchanged.

        ``v_pix_img`` is not stored: its value is not used, it only serves as the
        input that receives the gradient computed in ``backward``.
        """
        ctx.save_for_backward(v_pix, img, index_img, vi)
        return img

    @staticmethod
    # pyre-fixme[14]: `backward` overrides method defined in `Function` inconsistently.
    def backward(ctx, grad_output: th.Tensor) -> Tuple[
        Optional[th.Tensor],
        Optional[th.Tensor],
        Optional[th.Tensor],
        Optional[th.Tensor],
        Optional[th.Tensor],
    ]:
        """Computes gradients for the inputs of ``forward``.

        Args:
            grad_output (Tensor): Gradient with respect to the returned image.

        Returns:
            Tuple with one entry per ``forward`` input
            ``(v_pix, v_pix_img, vi, img, index_img)``. Only ``v_pix_img`` (a tensor
            with the x, y and z gradient components stacked along the last dimension)
            and ``img`` (``grad_output`` unchanged) receive gradients; the remaining
            entries are ``None``.
        """

        # early exit in case geometry is not optimized.
        if not ctx.needs_input_grad[1]:
            return None, None, None, grad_output, None

        v_pix, img, index_img, vi = ctx.saved_tensors

        # Finite differences between horizontally / vertically adjacent pixels.
        x_grad = img[:, :, :, 1:] - img[:, :, :, :-1]
        y_grad = img[:, :, 1:, :] - img[:, :, :-1, :]

        l_index = index_img[:, None, :, :-1]
        r_index = index_img[:, None, :, 1:]
        t_index = index_img[:, None, :-1, :]
        b_index = index_img[:, None, 1:, :]

        # Pixel pairs that belong to different primitives.
        x_mask = r_index != l_index
        y_mask = b_index != t_index

        # Pixel pairs where neither side is background (index -1).
        x_both_triangles = (r_index != -1) & (l_index != -1)
        y_both_triangles = (b_index != -1) & (t_index != -1)

        # Background pixels are clamped to primitive 0 so that indexing stays valid;
        # their contributions are masked out further below.
        iimg_clamped = index_img.clamp(min=0).long()

        # compute barycentric coordinates
        b = v_pix.shape[0]
        vi_img = index(vi, iimg_clamped, 0).long()
        p0 = th.cat(
            [index(v_pix[i], vi_img[i, ..., 0].data, 0)[None, ...] for i in range(b)],
            dim=0,
        )
        p1 = th.cat(
            [index(v_pix[i], vi_img[i, ..., 1].data, 0)[None, ...] for i in range(b)],
            dim=0,
        )
        p2 = th.cat(
            [index(v_pix[i], vi_img[i, ..., 2].data, 0)[None, ...] for i in range(b)],
            dim=0,
        )

        v10 = p1 - p0
        v02 = p0 - p2
        # `dim=-1` must be explicit: without it `th.cross` operates along the first
        # dimension of size 3, which is the batch, height or width axis whenever one
        # of them happens to equal 3.
        n = th.cross(v02, v10, dim=-1)

        # Integer pixel coordinates, each of shape (H, W). `px` holds the column (x)
        # index and `py` the row (y) index, so that they line up with the x and y
        # components of the vertex positions and with the width / height slicing
        # applied to the results of `check_if_point_inside_triangle`.
        py, px = th.meshgrid(
            th.arange(img.shape[-2], device=v_pix.device),
            th.arange(img.shape[-1], device=v_pix.device),
            indexing="ij",
        )

        def epsclamp(x: th.Tensor) -> th.Tensor:
            # Keeps the sign of `x` but pushes its magnitude to at least 1e-8, so it
            # is safe to divide by.
            return th.where(x < 0, x.clamp(max=-1e-8), x.clamp(min=1e-8))

        # pyre-fixme[53]: Captured variable `n` is not annotated.
        # pyre-fixme[53]: Captured variable `p0` is not annotated.
        # pyre-fixme[53]: Captured variable `px` is not annotated.
        # pyre-fixme[53]: Captured variable `py` is not annotated.
        # pyre-fixme[53]: Captured variable `v02` is not annotated.
        # pyre-fixme[53]: Captured variable `v10` is not annotated.
        def check_if_point_inside_triangle(offset_x: int, offset_y: int) -> th.Tensor:
            # For every pixel, tests whether the pixel position shifted by
            # (offset_x, offset_y) lies inside the triangle rasterized at that pixel.
            _px = px + offset_x
            _py = py + offset_y

            vp0p = th.stack([p0[..., 0] - _px, p0[..., 1] - _py], dim=-1) / epsclamp(
                n[..., 2:3]
            )

            bary_1 = v02[..., 0] * -vp0p[..., 1] + v02[..., 1] * vp0p[..., 0]
            bary_2 = v10[..., 0] * -vp0p[..., 1] + v10[..., 1] * vp0p[..., 0]

            return ((bary_1 > 0) & (bary_2 > 0) & ((bary_1 + bary_2) < 1))[:, None]

        left_pnt_inside_right_triangle = (
            check_if_point_inside_triangle(-1, 0)[..., :, 1:]
            & x_mask
            & x_both_triangles
        )
        right_pnt_inside_left_triangle = (
            check_if_point_inside_triangle(1, 0)[..., :, :-1]
            & x_mask
            & x_both_triangles
        )
        down_pnt_inside_up_triangle = (
            check_if_point_inside_triangle(0, 1)[..., :-1, :]
            & y_mask
            & y_both_triangles
        )
        up_pnt_inside_down_triangle = (
            check_if_point_inside_triangle(0, -1)[..., 1:, :]
            & y_mask
            & y_both_triangles
        )

        # Both neighbors lie inside each other's triangle.
        horizontal_intersection = (
            right_pnt_inside_left_triangle & left_pnt_inside_right_triangle
        )
        vertical_intersection = (
            down_pnt_inside_up_triangle & up_pnt_inside_down_triangle
        )

        # Exactly one of the neighbors lies inside the other's triangle.
        left_hangs_over_right = left_pnt_inside_right_triangle & (
            ~right_pnt_inside_left_triangle
        )
        right_hangs_over_left = right_pnt_inside_left_triangle & (
            ~left_pnt_inside_right_triangle
        )
        up_hangs_over_down = up_pnt_inside_down_triangle & (
            ~down_pnt_inside_up_triangle
        )
        down_hangs_over_up = down_pnt_inside_up_triangle & (
            ~up_pnt_inside_down_triangle
        )

        # Only differences across primitive boundaries contribute.
        x_grad *= x_mask
        y_grad *= y_mask

        # Incoming gradient averaged over the two pixels of each pair.
        grad_output_x = 0.5 * (grad_output[:, :, :, 1:] + grad_output[:, :, :, :-1])
        grad_output_y = 0.5 * (grad_output[:, :, 1:, :] + grad_output[:, :, :-1, :])

        x_grad = (x_grad * grad_output_x).sum(dim=1)
        y_grad = (y_grad * grad_output_y).sum(dim=1)

        x_grad_no_int = x_grad * (~horizontal_intersection[:, 0])
        y_grad_no_int = y_grad * (~vertical_intersection[:, 0])

        # Full-resolution (N, H, W) accumulators for the per-pixel gradients.
        x_grad_spread = th.zeros(
            x_grad_no_int.shape[0],
            x_grad_no_int.shape[1],
            y_grad_no_int.shape[2],
            dtype=x_grad_no_int.dtype,
            device=x_grad_no_int.device,
        )
        x_grad_spread[:, :, :-1] = x_grad_no_int * (~right_hangs_over_left[:, 0])
        x_grad_spread[:, :, 1:] += x_grad_no_int * (~left_hangs_over_right[:, 0])

        y_grad_spread = th.zeros_like(x_grad_spread)
        y_grad_spread[:, :-1, :] = y_grad_no_int * (~down_hangs_over_up[:, 0])
        y_grad_spread[:, 1:, :] += y_grad_no_int * (~up_hangs_over_down[:, 0])

        # Intersections. Compute border sliding gradients
        #################################################

        z_grad_spread = th.zeros_like(x_grad_spread)
        x_grad_int = x_grad * horizontal_intersection[:, 0]
        y_grad_int = y_grad * vertical_intersection[:, 0]

        n = thf.normalize(n, dim=-1)
        n = n.permute(0, 3, 1, 2)

        n_left = n[..., :, :-1]
        n_right = n[..., :, 1:]
        n_up = n[..., :-1, :]
        n_down = n[..., 1:, :]

        def get_dp_db(v_varying: th.Tensor, v_fixed: th.Tensor) -> th.Tensor:
            """
            Computes derivative of the point position with respect to edge displacement
            See drtk/src/edge_grad/edge_grad_kernel.cu
            Please refer to the paper "Rasterized Edge Gradients: Handling
            Discontinuities Differentiably" for details.
            """

            v_varying = thf.normalize(v_varying, dim=1)
            v_fixed = thf.normalize(v_fixed, dim=1)
            b = th.stack([-v_fixed[:, 1], v_fixed[:, 0]], dim=1)
            b_dot_varying = (b * v_varying).sum(dim=1, keepdim=True)
            return b[:, 0:1] / epsclamp(b_dot_varying) * v_varying

        # We compute partial derivatives by fixing one triangle and moving the
        # other, and then vice versa.

        # Left triangle moves, right fixed
        dp_dbx = get_dp_db(n_left[:, [0, 2]], -n_right[:, [0, 2]])
        x_grad_spread[:, :, :-1] += x_grad_int * dp_dbx[:, 0]
        z_grad_spread[:, :, :-1] += x_grad_int * dp_dbx[:, 1]

        # Left triangle fixed, right moves
        dp_dbx = get_dp_db(n_right[:, [0, 2]], n_left[:, [0, 2]])
        x_grad_spread[:, :, 1:] += x_grad_int * dp_dbx[:, 0]
        z_grad_spread[:, :, 1:] += x_grad_int * dp_dbx[:, 1]

        # Upper triangle moves, lower fixed
        dp_dby = get_dp_db(n_up[:, [1, 2]], -n_down[:, [1, 2]])
        y_grad_spread[:, :-1, :] += y_grad_int * dp_dby[:, 0]
        z_grad_spread[:, :-1, :] += y_grad_int * dp_dby[:, 1]

        # Lower triangle moves, upper fixed
        dp_dby = get_dp_db(n_down[:, [1, 2]], n_up[:, [1, 2]])
        y_grad_spread[:, 1:, :] += y_grad_int * dp_dby[:, 0]
        z_grad_spread[:, 1:, :] += y_grad_int * dp_dby[:, 1]

        # Background pixels do not belong to any primitive and get no gradient.
        m = index_img == -1
        x_grad_spread[m] = 0.0
        y_grad_spread[m] = 0.0

        grad_v_pix = -th.stack([x_grad_spread, y_grad_spread, z_grad_spread], dim=3)

        return None, grad_v_pix, None, grad_output, None