Commit 72c3a0eb authored by Darijan Gudelj's avatar Darijan Gudelj Committed by Facebook GitHub Bot
Browse files

raybundle input to ImplicitFunctions -> api unification

Summary: Currently some implicit functions in implicitron take a raybundle, others take ray_points_world. raybundle is what they really need. However, the raybundle is going to become a bit more flexible later, as it will contain different numbers of rays for each camera.

Reviewed By: bottler

Differential Revision: D39173751

fbshipit-source-id: ebc038e426d22e831e67a18ba64655d8a61e1eb9
parent 70dc9c45
......@@ -19,6 +19,7 @@ class ImplicitFunctionBase(ABC, ReplaceableBase):
@abstractmethod
def forward(
self,
*,
ray_bundle: RayBundle,
fun_viewpool=None,
camera: Optional[CamerasBase] = None,
......
......@@ -3,14 +3,15 @@
# implicit_differentiable_renderer.py
# Copyright (c) 2020 Lior Yariv
import math
from typing import Tuple
from typing import Optional, Tuple
import torch
from pytorch3d.implicitron.tools.config import registry
from pytorch3d.renderer.implicit import HarmonicEmbedding
from pytorch3d.renderer.implicit import HarmonicEmbedding, RayBundle
from torch import nn
from .base import ImplicitFunctionBase
from .utils import get_rays_points_world
@registry.register
......@@ -125,14 +126,16 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
# inconsistently.
def forward(
self,
# ray_bundle: RayBundle,
rays_points_world: torch.Tensor, # TODO: unify the APIs
*,
ray_bundle: Optional[RayBundle] = None,
rays_points_world: Optional[torch.Tensor] = None,
fun_viewpool=None,
global_code=None,
**kwargs,
):
# this field only uses point locations
# rays_points_world = ray_bundle_to_ray_points(ray_bundle)
# rays_points_world.shape = [minibatch x ... x pts_per_ray x 3]
rays_points_world = get_rays_points_world(ray_bundle, rays_points_world)
if rays_points_world.numel() == 0 or (
self.embed_fn is None and fun_viewpool is None and global_code is None
......@@ -179,4 +182,4 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
x = self.softplus(x)
return x # TODO: unify the APIs
return x
......@@ -129,6 +129,7 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
def forward(
self,
*,
ray_bundle: RayBundle,
fun_viewpool=None,
camera: Optional[CamerasBase] = None,
......
......@@ -349,6 +349,7 @@ class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
def forward(
self,
*,
ray_bundle: RayBundle,
fun_viewpool=None,
camera: Optional[CamerasBase] = None,
......@@ -408,6 +409,7 @@ class SRNHyperNetImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
def forward(
self,
*,
ray_bundle: RayBundle,
fun_viewpool=None,
camera: Optional[CamerasBase] = None,
......
......@@ -10,7 +10,9 @@ import torch
import torch.nn.functional as F
from pytorch3d.common.compat import prod
from pytorch3d.renderer import ray_bundle_to_ray_points
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.renderer.implicit import RayBundle
def broadcast_global_code(embeds: torch.Tensor, global_code: torch.Tensor):
......@@ -185,3 +187,31 @@ def interpolate_volume(
**kwargs,
)
return out[:, :, :, 0, 0].permute(0, 2, 1)
def get_rays_points_world(
ray_bundle: Optional[RayBundle] = None,
rays_points_world: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Converts the ray_bundle to rays_points_world if rays_points_world is not defined
and raises error if both are defined.
Args:
ray_bundle: A RayBundle object or None
rays_points_world: A torch.Tensor representing ray points converted to
world coordinates
Returns:
A torch.Tensor representing ray points converted to world coordinates
of shape [minibatch x ... x pts_per_ray x 3].
"""
if rays_points_world is not None and ray_bundle is not None:
raise ValueError(
"Cannot define both rays_points_world and ray_bundle,"
+ " one has to be None."
)
if rays_points_world is not None:
return rays_points_world
if ray_bundle is not None:
return ray_bundle_to_ray_points(ray_bundle)
raise ValueError("ray_bundle and rays_points_world cannot both be None")
......@@ -118,7 +118,7 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
# eval the raymarching function
raymarch_features, _ = implicit_function(
ray_bundle_t,
ray_bundle=ray_bundle_t,
raymarch_features=None,
)
if self.verbose:
......
......@@ -148,7 +148,7 @@ class MultiPassEmissionAbsorptionRenderer( # pyre-ignore: 13
)
output = self.raymarcher(
*implicit_functions[0](ray_bundle),
*implicit_functions[0](ray_bundle=ray_bundle),
ray_lengths=ray_bundle.lengths,
density_noise_std=density_noise_std,
)
......
......@@ -101,7 +101,7 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ign
object_mask = object_mask.bool()
implicit_function = implicit_functions[0]
implicit_function_gradient = functools.partial(gradient, implicit_function)
implicit_function_gradient = functools.partial(_gradient, implicit_function)
# object_mask: silhouette of the object
batch_size, *spatial_size, _ = ray_bundle.lengths.shape
......@@ -113,7 +113,7 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ign
with torch.no_grad(), evaluating(implicit_function):
points, network_object_mask, dists = self.ray_tracer(
sdf=lambda x: implicit_function(x)[
sdf=lambda x: implicit_function(rays_points_world=x)[
:, 0
], # TODO: get rid of this wrapper
cam_loc=cam_loc,
......@@ -125,7 +125,7 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ign
depth = dists.reshape(batch_size, num_pixels, 1)
points = (cam_loc + depth * ray_dirs).reshape(-1, 3)
sdf_output = implicit_function(points)[:, 0:1]
sdf_output = implicit_function(rays_points_world=points)[:, 0:1]
# NOTE most of the intermediate variables are flattened for
# no apparent reason (here and in the ray tracer)
ray_dirs = ray_dirs.reshape(-1, 3)
......@@ -157,7 +157,7 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ign
points_all = torch.cat([surface_points, eikonal_points], dim=0)
output = implicit_function(surface_points)
output = implicit_function(rays_points_world=surface_points)
surface_sdf_values = output[
:N, 0:1
].detach() # how is it different from sdf_output?
......@@ -181,7 +181,9 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ign
grad_theta = None
empty_render = differentiable_surface_points.shape[0] == 0
features = implicit_function(differentiable_surface_points)[None, :, 1:]
features = implicit_function(rays_points_world=differentiable_surface_points)[
None, :, 1:
]
normals_full = features.new_zeros(
batch_size, *spatial_size, 3, requires_grad=empty_render
)
......@@ -260,13 +262,13 @@ def _sample_network(
@torch.enable_grad()
def gradient(module, x):
x.requires_grad_(True)
y = module.forward(x)[:, :1]
def _gradient(module, rays_points_world):
rays_points_world.requires_grad_(True)
y = module.forward(rays_points_world=rays_points_world)[:, :1]
d_output = torch.ones_like(y, requires_grad=False, device=y.device)
gradients = torch.autograd.grad(
outputs=y,
inputs=x,
inputs=rays_points_world,
grad_outputs=d_output,
create_graph=True,
retain_graph=True,
......
......@@ -44,7 +44,7 @@ class TestSRN(TestCaseMixin, unittest.TestCase):
implicit_function = SRNImplicitFunction()
device = torch.device("cpu")
bundle = self._get_bundle(device=device)
rays_densities, rays_colors = implicit_function(bundle)
rays_densities, rays_colors = implicit_function(ray_bundle=bundle)
out_features = implicit_function.raymarch_function.out_features
self.assertEqual(
rays_densities.shape,
......@@ -62,7 +62,9 @@ class TestSRN(TestCaseMixin, unittest.TestCase):
implicit_function.to(device)
global_code = torch.rand(_BATCH_SIZE, latent_dim_hypernet, device=device)
bundle = self._get_bundle(device=device)
rays_densities, rays_colors = implicit_function(bundle, global_code=global_code)
rays_densities, rays_colors = implicit_function(
ray_bundle=bundle, global_code=global_code
)
out_features = implicit_function.hypernet.out_features
self.assertEqual(
rays_densities.shape,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment