"git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "e417035f5d473b9f85d15ba01267d48d7f30e71e"
Unverified Commit 7995c54a authored by Matthew Tancik's avatar Matthew Tancik Committed by GitHub
Browse files

Fix typing errors (#16)

* Fix typing errors

* Add aabb
parent cc700f5b
...@@ -6,15 +6,16 @@ from torch import nn ...@@ -6,15 +6,16 @@ from torch import nn
# from torch_scatter import scatter_max # from torch_scatter import scatter_max
def meshgrid3d(res: Tuple[int, int, int], device: torch.device = "cpu"): def meshgrid3d(res: List[int], device: Union[torch.device, str] = "cpu") -> torch.Tensor:
"""Create 3D grid coordinates. """Create 3D grid coordinates.
Args: Args:
res (Tuple[int, int, int]): resolutions for {x, y, z} dimensions. res: resolutions for {x, y, z} dimensions.
Returns: Returns:
torch.long with shape (res[0], res[1], res[2], 3): dense 3D grid coordinates. torch.long with shape (res[0], res[1], res[2], 3): dense 3D grid coordinates.
""" """
assert len(res) == 3
return ( return (
torch.stack( torch.stack(
torch.meshgrid( torch.meshgrid(
...@@ -48,7 +49,19 @@ class OccupancyField(nn.Module): ...@@ -48,7 +49,19 @@ class OccupancyField(nn.Module):
to specify resolution on each dimention. If ``num_dim=2`` it is for {res_x, res_y}. to specify resolution on each dimention. If ``num_dim=2`` it is for {res_x, res_y}.
If ``num_dim=3`` it is for {res_x, res_y, res_z}. Default is 128. If ``num_dim=3`` it is for {res_x, res_y, res_z}. Default is 128.
num_dim: The space dimension. Either 2 or 3. Default is 3. num_dim: The space dimension. Either 2 or 3. Default is 3.
Attributes:
aabb: Scene bounding box.
occ_grid: The occupancy grid. It is a tensor of shape (num_cells,).
occ_grid_binary: The binary occupancy grid. It is a tensor of shape (num_cells,).
grid_coords: The grid coordinates. It is a tensor of shape (num_cells, num_dim).
grid_indices: The grid indices. It is a tensor of shape (num_cells,).
""" """
aabb = torch.Tensor
occ_grid: torch.Tensor
occ_grid_binary: torch.Tensor
grid_coords: torch.Tensor
grid_indices: torch.Tensor
def __init__( def __init__(
self, self,
...@@ -75,7 +88,7 @@ class OccupancyField(nn.Module): ...@@ -75,7 +88,7 @@ class OccupancyField(nn.Module):
self.resolution = resolution self.resolution = resolution
self.register_buffer("resolution_tensor", torch.tensor(resolution)) self.register_buffer("resolution_tensor", torch.tensor(resolution))
self.num_dim = num_dim self.num_dim = num_dim
self.num_cells = torch.tensor(resolution).prod().item() self.num_cells = int(torch.tensor(resolution).prod().item())
# Stores cell occupancy values ranged in [0, 1]. # Stores cell occupancy values ranged in [0, 1].
occ_grid = torch.zeros(self.num_cells) occ_grid = torch.zeros(self.num_cells)
...@@ -180,9 +193,11 @@ class OccupancyField(nn.Module): ...@@ -180,9 +193,11 @@ class OccupancyField(nn.Module):
+ grid_coords[..., 1] * self.resolution[-1] + grid_coords[..., 1] * self.resolution[-1]
+ grid_coords[..., 2] + grid_coords[..., 2]
) )
else:
raise NotImplementedError("Currently only supports 2D or 3D field.")
occs = torch.zeros(x.shape[:-1], device=x.device) occs = torch.zeros(x.shape[:-1], device=x.device)
occs[selector] = self.occ_grid[grid_indices[selector]] occs[selector] = self.occ_grid[grid_indices[selector]]
occs_binary = torch.zeros(x.shape[:-1], device=x.device, dtype=bool) occs_binary = torch.zeros(x.shape[:-1], device=x.device, dtype=torch.bool)
occs_binary[selector] = self.occ_grid_binary[grid_indices[selector]] occs_binary[selector] = self.occ_grid_binary[grid_indices[selector]]
return occs, occs_binary return occs, occs_binary
......
from typing import Tuple from typing import Tuple, Optional, List
import torch import torch
from torch import Tensor from torch import Tensor
...@@ -40,10 +40,10 @@ def volumetric_marching( ...@@ -40,10 +40,10 @@ def volumetric_marching(
rays_o: Tensor, rays_o: Tensor,
rays_d: Tensor, rays_d: Tensor,
aabb: Tensor, aabb: Tensor,
scene_resolution: Tuple[int, int, int], scene_resolution: List[int],
scene_occ_binary: Tensor, scene_occ_binary: Tensor,
t_min: Tensor = None, t_min: Optional[Tensor] = None,
t_max: Tensor = None, t_max: Optional[Tensor] = None,
render_step_size: float = 1e-3, render_step_size: float = 1e-3,
near_plane: float = 0.0, near_plane: float = 0.0,
stratified: bool = False stratified: bool = False
...@@ -130,7 +130,7 @@ def volumetric_rendering_steps( ...@@ -130,7 +130,7 @@ def volumetric_rendering_steps(
frustum_starts: Tensor, frustum_starts: Tensor,
frustum_ends: Tensor, frustum_ends: Tensor,
*args, *args,
) -> Tuple[Tensor, Tensor, Tensor]: ) -> Tuple[Tensor, ...]:
"""Compute rendering marching steps. """Compute rendering marching steps.
This function will compact the samples by terminate the marching once the \ This function will compact the samples by terminate the marching once the \
...@@ -188,7 +188,7 @@ def volumetric_rendering_weights( ...@@ -188,7 +188,7 @@ def volumetric_rendering_weights(
sigmas: Tensor, sigmas: Tensor,
frustum_starts: Tensor, frustum_starts: Tensor,
frustum_ends: Tensor, frustum_ends: Tensor,
) -> Tuple[Tensor, Tensor, Tensor]: ) -> Tuple[Tensor, Tensor]:
"""Compute weights for volumetric rendering. """Compute weights for volumetric rendering.
Note: this function is only differentiable to `sigmas`. Note: this function is only differentiable to `sigmas`.
...@@ -230,8 +230,8 @@ def volumetric_rendering_weights( ...@@ -230,8 +230,8 @@ def volumetric_rendering_weights(
def volumetric_rendering_accumulate( def volumetric_rendering_accumulate(
weights: Tensor, weights: Tensor,
ray_indices: Tensor, ray_indices: Tensor,
values: Tensor = None, values: Optional[Tensor] = None,
n_rays: int = None, n_rays: Optional[int] = None,
) -> Tensor: ) -> Tensor:
"""Accumulate volumetric values along the ray. """Accumulate volumetric values along the ray.
...@@ -265,7 +265,7 @@ def volumetric_rendering_accumulate( ...@@ -265,7 +265,7 @@ def volumetric_rendering_accumulate(
return torch.zeros((n_rays, src.shape[-1]), device=weights.device) return torch.zeros((n_rays, src.shape[-1]), device=weights.device)
if n_rays is None: if n_rays is None:
n_rays = ray_indices.max() + 1 n_rays = int(ray_indices.max()) + 1
else: else:
assert n_rays > ray_indices.max() assert n_rays > ray_indices.max()
......
from typing import Callable, Tuple from typing import Callable, Tuple, List
import torch import torch
...@@ -16,12 +16,12 @@ def volumetric_rendering( ...@@ -16,12 +16,12 @@ def volumetric_rendering(
rays_d: torch.Tensor, rays_d: torch.Tensor,
scene_aabb: torch.Tensor, scene_aabb: torch.Tensor,
scene_occ_binary: torch.Tensor, scene_occ_binary: torch.Tensor,
scene_resolution: Tuple[int, int, int], scene_resolution: List[int],
render_bkgd: torch.Tensor, render_bkgd: torch.Tensor,
render_step_size: int, render_step_size: int,
near_plane: float = 0.0, near_plane: float = 0.0,
stratified: bool = False, stratified: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]:
"""A *fast* version of differentiable volumetric rendering.""" """A *fast* version of differentiable volumetric rendering."""
n_rays = rays_o.shape[0] n_rays = rays_o.shape[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment