Unverified Commit 7995c54a authored by Matthew Tancik's avatar Matthew Tancik Committed by GitHub
Browse files

Fix typing errors (#16)

* Fix typing errors

* Add aabb
parent cc700f5b
......@@ -6,15 +6,16 @@ from torch import nn
# from torch_scatter import scatter_max
def meshgrid3d(res: Tuple[int, int, int], device: torch.device = "cpu"):
def meshgrid3d(res: List[int], device: Union[torch.device, str] = "cpu") -> torch.Tensor:
"""Create 3D grid coordinates.
Args:
res (Tuple[int, int, int]): resolutions for {x, y, z} dimensions.
res: resolutions for {x, y, z} dimensions.
Returns:
torch.long with shape (res[0], res[1], res[2], 3): dense 3D grid coordinates.
"""
assert len(res) == 3
return (
torch.stack(
torch.meshgrid(
......@@ -48,7 +49,19 @@ class OccupancyField(nn.Module):
to specify resolution on each dimention. If ``num_dim=2`` it is for {res_x, res_y}.
If ``num_dim=3`` it is for {res_x, res_y, res_z}. Default is 128.
num_dim: The space dimension. Either 2 or 3. Default is 3.
Attributes:
aabb: Scene bounding box.
occ_grid: The occupancy grid. It is a tensor of shape (num_cells,).
occ_grid_binary: The binary occupancy grid. It is a tensor of shape (num_cells,).
grid_coords: The grid coordinates. It is a tensor of shape (num_cells, num_dim).
grid_indices: The grid indices. It is a tensor of shape (num_cells,).
"""
aabb = torch.Tensor
occ_grid: torch.Tensor
occ_grid_binary: torch.Tensor
grid_coords: torch.Tensor
grid_indices: torch.Tensor
def __init__(
self,
......@@ -75,7 +88,7 @@ class OccupancyField(nn.Module):
self.resolution = resolution
self.register_buffer("resolution_tensor", torch.tensor(resolution))
self.num_dim = num_dim
self.num_cells = torch.tensor(resolution).prod().item()
self.num_cells = int(torch.tensor(resolution).prod().item())
# Stores cell occupancy values ranged in [0, 1].
occ_grid = torch.zeros(self.num_cells)
......@@ -180,9 +193,11 @@ class OccupancyField(nn.Module):
+ grid_coords[..., 1] * self.resolution[-1]
+ grid_coords[..., 2]
)
else:
raise NotImplementedError("Currently only supports 2D or 3D field.")
occs = torch.zeros(x.shape[:-1], device=x.device)
occs[selector] = self.occ_grid[grid_indices[selector]]
occs_binary = torch.zeros(x.shape[:-1], device=x.device, dtype=bool)
occs_binary = torch.zeros(x.shape[:-1], device=x.device, dtype=torch.bool)
occs_binary[selector] = self.occ_grid_binary[grid_indices[selector]]
return occs, occs_binary
......
from typing import Tuple
from typing import Tuple, Optional, List
import torch
from torch import Tensor
......@@ -40,10 +40,10 @@ def volumetric_marching(
rays_o: Tensor,
rays_d: Tensor,
aabb: Tensor,
scene_resolution: Tuple[int, int, int],
scene_resolution: List[int],
scene_occ_binary: Tensor,
t_min: Tensor = None,
t_max: Tensor = None,
t_min: Optional[Tensor] = None,
t_max: Optional[Tensor] = None,
render_step_size: float = 1e-3,
near_plane: float = 0.0,
stratified: bool = False
......@@ -130,7 +130,7 @@ def volumetric_rendering_steps(
frustum_starts: Tensor,
frustum_ends: Tensor,
*args,
) -> Tuple[Tensor, Tensor, Tensor]:
) -> Tuple[Tensor, ...]:
"""Compute rendering marching steps.
This function will compact the samples by terminate the marching once the \
......@@ -188,7 +188,7 @@ def volumetric_rendering_weights(
sigmas: Tensor,
frustum_starts: Tensor,
frustum_ends: Tensor,
) -> Tuple[Tensor, Tensor, Tensor]:
) -> Tuple[Tensor, Tensor]:
"""Compute weights for volumetric rendering.
Note: this function is only differentiable to `sigmas`.
......@@ -230,8 +230,8 @@ def volumetric_rendering_weights(
def volumetric_rendering_accumulate(
weights: Tensor,
ray_indices: Tensor,
values: Tensor = None,
n_rays: int = None,
values: Optional[Tensor] = None,
n_rays: Optional[int] = None,
) -> Tensor:
"""Accumulate volumetric values along the ray.
......@@ -265,7 +265,7 @@ def volumetric_rendering_accumulate(
return torch.zeros((n_rays, src.shape[-1]), device=weights.device)
if n_rays is None:
n_rays = ray_indices.max() + 1
n_rays = int(ray_indices.max()) + 1
else:
assert n_rays > ray_indices.max()
......
from typing import Callable, Tuple
from typing import Callable, Tuple, List
import torch
......@@ -16,12 +16,12 @@ def volumetric_rendering(
rays_d: torch.Tensor,
scene_aabb: torch.Tensor,
scene_occ_binary: torch.Tensor,
scene_resolution: Tuple[int, int, int],
scene_resolution: List[int],
render_bkgd: torch.Tensor,
render_step_size: int,
near_plane: float = 0.0,
stratified: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]:
"""A *fast* version of differentiable volumetric rendering."""
n_rays = rays_o.shape[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment