Commit e813bcaa authored by Ruilong Li's avatar Ruilong Li
Browse files

occ field doc

parent 1ed5257d
......@@ -3,33 +3,42 @@ from typing import Callable, List, Tuple, Union
import torch
from torch import nn
from .grid import meshgrid
from ._grid import meshgrid
class OccupancyField(nn.Module):
"""Occupancy Field."""
"""Occupancy Field that supports EMA updates.
It supports both 2D and 3D cases, where in the 2D cases the occupancy field
is basically a segmentation mask.
Args:
occ_eval_fn: A Callable function that takes in the un-normalized points x,
with shape of (N, 2) or (N, 3) (depends on `num_dim`), and outputs
the occupancy of those points with shape of (N, 1).
aabb: Scene bounding box. {min_x, min_y, (min_z), max_x, max_y, (max_z)}.
It can be either a list or a torch.Tensor.
resolution: The field resolution. It can either be a int of a list of ints
to specify resolution on each dimention. {res_x, res_y, (res_z)}. Default
is 128.
num_dim: The space dimension. Either 2 or 3. Default is 3. Note other arguments
should match with the space dimension being set here.
"""
def __init__(
self,
# Shape (N, 3) -> (N, 1). Values are in range [0, 1]: density * step_size
occ_eval_fn: Callable,
aabb: Union[torch.Tensor, List[float]],
resolution: Union[int, List[int]], # cell resolution
resolution: Union[int, List[int]] = 128,
num_dim: int = 3,
) -> None:
# def occ_eval_fn(x):
# step_size = (rays.far - rays.near).max() / self.num_samples
# densities, _ = self.radiance_field.query_density(x)
# occ = densities * step_size
# return occ
super().__init__()
self.occ_eval_fn = occ_eval_fn
if not isinstance(aabb, torch.Tensor):
aabb = torch.tensor(aabb, dtype=torch.float32)
if not isinstance(resolution, (list, tuple)):
resolution = [resolution] * num_dim
assert num_dim in [2, 3], "Currently only supports 2D or 3D field."
assert aabb.shape == (
num_dim * 2,
), f"shape of aabb ({aabb.shape}) should be num_dim * 2 ({num_dim * 2})."
......@@ -45,7 +54,6 @@ class OccupancyField(nn.Module):
# Stores cell occupancy values ranged in [0, 1].
occ_grid = torch.zeros(self.num_cells)
self.register_buffer("occ_grid", occ_grid)
occ_grid_binary = torch.zeros(self.num_cells, dtype=torch.bool)
self.register_buffer("occ_grid_binary", occ_grid_binary)
......@@ -53,23 +61,20 @@ class OccupancyField(nn.Module):
occ_grid_mean = occ_grid.mean()
self.register_buffer("occ_grid_mean", occ_grid_mean)
# Grid coords & indices
grid_coords = meshgrid(self.resolution).reshape(self.num_cells, self.num_dim)
self.register_buffer("grid_coords", grid_coords)
grid_indices = torch.arange(self.num_cells)
self.register_buffer("grid_indices", grid_indices)
@torch.no_grad()
def get_all_cells(
self,
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
def _get_all_cells(self) -> torch.Tensor:
"""Returns all cells of the grid."""
return self.grid_indices
@torch.no_grad()
def sample_uniform_and_occupied_cells(
self, n: int
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
"""Samples both n uniform and occupied cells (per level)."""
def _sample_uniform_and_occupied_cells(self, n: int) -> torch.Tensor:
"""Samples both n uniform and occupied cells."""
device = self.occ_grid.device
uniform_indices = torch.randint(self.num_cells, (n,), device=device)
......@@ -88,16 +93,26 @@ class OccupancyField(nn.Module):
step: int,
occ_threshold: float = 0.01,
ema_decay: float = 0.95,
warmup_steps: int = 256,
) -> None:
"""Update the occ_grid (as well as occ_bitfield) in EMA way."""
"""Update the occ field in the EMA way.
Args:
step: Current training step.
occ_threshold: Threshold to binarize the occupancy field.
ema_decay: The decay rate for EMA updates.
warmup_steps: Sample all cells during the warmup stage. After the warmup
stage we change the sampling strategy to 1/4 unifromly sampled cells
together with 1/4 occupied cells.
"""
resolution = torch.tensor(self.resolution).to(self.occ_grid.device)
# sample cells
if step < 256:
indices = self.get_all_cells()
if step < warmup_steps:
indices = self._get_all_cells()
else:
N = resolution.prod().item() // 4
indices = self.sample_uniform_and_occupied_cells(N)
indices = self._sample_uniform_and_occupied_cells(N)
# infer occupancy: density * step_size
tmp_occ_grid = -torch.ones_like(self.occ_grid)
......@@ -118,8 +133,19 @@ class OccupancyField(nn.Module):
)
@torch.no_grad()
def query_occ(self, x: torch.Tensor) -> torch.Tensor:
"""Query the occ_grid."""
def query_occ(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Query the occupancy, given samples.
Args:
x: Samples with shape (..., 2) or (..., 3).
Returns:
float occupancy values with shape (...),
binary occupancy values with shape (...)
"""
assert (
x.shape[-1] == self.num_dim
), "The samples are not drawn from a proper space!"
resolution = torch.tensor(self.resolution).to(self.occ_grid.device)
bb_min, bb_max = torch.split(self.aabb, [self.num_dim, self.num_dim], dim=0)
......@@ -144,10 +170,21 @@ class OccupancyField(nn.Module):
return occs, occs_binary
@torch.no_grad()
def every_n_step(self, step: int, n: int = 16):
def every_n_step(
self,
step: int,
occ_thre: float = 1e-2,
ema_decay: float = 0.95,
n: int = 16,
):
if not self.training:
raise RuntimeError(
"You should only call this function during training. Please call update() "
"directly if you want to update the field during inference."
)
if step % n == 0 and self.training:
self.update(
step=step,
occ_threshold=0.01,
ema_decay=0.95,
occ_threshold=occ_thre,
ema_decay=ema_decay,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment