Commit 5009fc12 authored by Ruilong Li

add occ field

parent 298ffd02
import math
from typing import Callable, Tuple
# Lines from the package __init__.py, interleaved here by the diff view:
#   from .occupancy_field import OccupancyField
#   from .volumetric_rendering import volumetric_rendering
import torch
from .cuda import VolumeRenderer, ray_aabb_intersect, ray_marching
def volumetric_rendering(
    query_fn: Callable,
    rays_o: torch.Tensor,
    rays_d: torch.Tensor,
    scene_aabb: torch.Tensor,
    scene_occ_binary: torch.Tensor,
    scene_resolution: Tuple[int, int, int],
    render_bkgd: torch.Tensor = None,
    render_n_samples: int = 1024,
    **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """A *fast* version of differentiable volumetric rendering.

    Marches rays through an occupancy grid (skipping empty cells), queries the
    radiance field at the sample midpoints, and composites with a CUDA renderer.

    Args:
        query_fn: Callable mapping (positions, directions, **kwargs) to a
            sequence whose first two entries are (rgbs, densities) for the
            sampled points.
        rays_o: Ray origins. Presumably shape (n_rays, 3) — TODO confirm
            against the CUDA kernels.
        rays_d: Ray directions, same leading shape as ``rays_o``.
        scene_aabb: Flat (6,) tensor [min_x, min_y, min_z, max_x, max_y, max_z].
        scene_occ_binary: Flattened binary occupancy grid used by the marcher
            to skip empty space.
        scene_resolution: Occupancy-grid resolution along the three axes.
        render_bkgd: Background color (3,); defaults to white.
        render_n_samples: Maximum number of samples per ray.
        **kwargs: Forwarded verbatim to ``query_fn``.

    Returns:
        Tuple of four tensors: (accumulated_color, accumulated_depth,
        accumulated_weight, alive_ray_mask). Note: the function returns four
        values; the annotation reflects that.
    """
    device = rays_o.device
    if render_bkgd is None:
        render_bkgd = torch.ones(3, device=device)
    scene_resolution = torch.tensor(scene_resolution, dtype=torch.int, device=device)

    # The CUDA kernels require contiguous memory.
    rays_o = rays_o.contiguous()
    rays_d = rays_d.contiguous()
    scene_aabb = scene_aabb.contiguous()
    scene_occ_binary = scene_occ_binary.contiguous()
    render_bkgd = render_bkgd.contiguous()

    n_rays = rays_o.shape[0]
    render_total_samples = n_rays * render_n_samples
    # Step so that render_n_samples steps span the AABB diagonal.
    render_step_size = (
        (scene_aabb[3:] - scene_aabb[:3]).max() * math.sqrt(3) / render_n_samples
    )

    with torch.no_grad():
        # TODO: avoid clamp here. kinda stupid
        t_min, t_max = ray_aabb_intersect(rays_o, rays_d, scene_aabb)
        t_min = torch.clamp(t_min, max=1e10)
        t_max = torch.clamp(t_max, max=1e10)
        (
            packed_info,
            frustum_origins,
            frustum_dirs,
            frustum_starts,
            frustum_ends,
        ) = ray_marching(
            # rays
            rays_o,
            rays_d,
            t_min,
            t_max,
            # density grid
            scene_aabb,
            scene_resolution,
            scene_occ_binary,
            # sampling
            render_total_samples,
            render_n_samples,
            render_step_size,
        )

    # Squeeze valid samples. Convert to a plain int (the sum is a 0-dim
    # tensor) and keep at least one sample so downstream slicing is non-empty.
    total_samples = max(int(packed_info[:, -1].sum()), 1)
    frustum_origins = frustum_origins[:total_samples]
    frustum_dirs = frustum_dirs[:total_samples]
    frustum_starts = frustum_starts[:total_samples]
    frustum_ends = frustum_ends[:total_samples]

    # Query the field at each frustum's midpoint.
    frustum_positions = (
        frustum_origins + frustum_dirs * (frustum_starts + frustum_ends) / 2.0
    )
    query_results = query_fn(frustum_positions, frustum_dirs, **kwargs)
    rgbs, densities = query_results[0], query_results[1]

    # Differentiable compositing (custom autograd Function).
    (
        accumulated_weight,
        accumulated_depth,
        accumulated_color,
        alive_ray_mask,
    ) = VolumeRenderer.apply(
        packed_info,
        frustum_starts,
        frustum_ends,
        densities.contiguous(),
        rgbs.contiguous(),
    )
    accumulated_depth = torch.clip(accumulated_depth, t_min[:, None], t_max[:, None])
    # Alpha-composite the background behind the accumulated radiance.
    accumulated_color = accumulated_color + render_bkgd * (1.0 - accumulated_weight)
    return accumulated_color, accumulated_depth, accumulated_weight, alive_ray_mask
from typing import List, Tuple
import torch
import torch.nn.functional as F
def query_grid(x: torch.Tensor, aabb: torch.Tensor, grid: torch.Tensor) -> torch.Tensor:
    """Interpolate values from a dense grid field at given coordinates.

    Args:
        x: 2D / 3D coordinates, with shape of [..., 2 or 3]
        aabb: 2D / 3D bounding box of the field, with shape of [4 or 6]
        grid: Grid with shape [res_x, res_y, res_z, D] or [res_x, res_y, D]

    Returns:
        values with shape [..., D]

    Raises:
        ValueError: if x / aabb / grid do not jointly match the 2D or 3D case.
    """
    out_shape = [*x.shape[:-1], grid.shape[-1]]
    is_2d = x.shape[-1] == 2 and aabb.shape == (4,) and grid.ndim == 3
    is_3d = x.shape[-1] == 3 and aabb.shape == (6,) and grid.ndim == 4
    if not (is_2d or is_3d):
        raise ValueError(
            "The shapes of the inputs do not match to either 2D case or 3D case! "
            f"Got x: {x.shape}; aabb: {aabb.shape}; grid: {grid.shape}."
        )
    if is_2d:
        # grid_sample wants [N, C, H, W] with coords ordered (x, y).
        batched = grid.permute(2, 1, 0).unsqueeze(0)  # [1, D, res_y, res_x]
        lo, hi = aabb[0:2], aabb[2:4]
        coords = (x.view(1, -1, 1, 2) - lo) / (hi - lo)
    else:
        # grid_sample wants [N, C, D, H, W] with coords ordered (x, y, z).
        batched = grid.permute(3, 2, 1, 0).unsqueeze(0)  # [1, D, res_z, res_y, res_x]
        lo, hi = aabb[0:3], aabb[3:6]
        coords = (x.view(1, -1, 1, 1, 3) - lo) / (hi - lo)
    # Map normalized [0, 1] coordinates to grid_sample's [-1, 1] range.
    sampled = F.grid_sample(
        batched,
        coords * 2.0 - 1.0,
        align_corners=True,
        padding_mode="zeros",
    )
    return sampled.reshape(out_shape)
def meshgrid(resolution: List[int]):
    """Build dense grid coordinates, dispatching on dimensionality (2 or 3)."""
    ndim = len(resolution)
    if ndim == 2:
        return meshgrid2d(resolution)
    if ndim == 3:
        return meshgrid3d(resolution)
    raise ValueError(resolution)
def meshgrid2d(res: Tuple[int, int], device: torch.device = "cpu"):
    """Create 2D grid coordinates.

    Args:
        res (Tuple[int, int]): resolutions for {x, y} dimensions.
        device: target device for the returned tensor.

    Returns:
        torch.long with shape (res[0], res[1], 2): dense 2D grid coordinates.
    """
    axes = torch.meshgrid(
        [torch.arange(res[0]), torch.arange(res[1])],
        indexing="ij",
    )
    coords = torch.stack(axes, dim=-1)
    return coords.long().to(device)
def meshgrid3d(res: Tuple[int, int, int], device: torch.device = "cpu"):
    """Create 3D grid coordinates.

    Args:
        res (Tuple[int, int, int]): resolutions for {x, y, z} dimensions.
        device: target device for the returned tensor.

    Returns:
        torch.long with shape (res[0], res[1], res[2], 3): dense 3D grid coordinates.
    """
    axes = torch.meshgrid(
        [torch.arange(res[0]), torch.arange(res[1]), torch.arange(res[2])],
        indexing="ij",
    )
    coords = torch.stack(axes, dim=-1)
    return coords.long().to(device)
from typing import Callable, List, Tuple, Union
import torch
from torch import nn
from .grid import meshgrid
class OccupancyField(nn.Module):
    """Occupancy Field.

    Maintains a dense grid of per-cell occupancy values in [0, 1], updated
    in an EMA fashion from a user-supplied evaluation function, plus a
    binarized version used for empty-space skipping during ray marching.
    All grid state is stored in registered buffers so it moves with the
    module across devices and is saved in state_dicts.
    """
    def __init__(
        self,
        # Shape (N, 3) -> (N, 1). Values are in range [0, 1]: density * step_size
        occ_eval_fn: Callable,
        aabb: Union[torch.Tensor, List[float]],
        resolution: Union[int, List[int]],  # cell resolution
        num_dim: int = 3,
    ) -> None:
        # Example occ_eval_fn:
        # def occ_eval_fn(x):
        #     step_size = (rays.far - rays.near).max() / self.num_samples
        #     densities, _ = self.radiance_field.query_density(x)
        #     occ = densities * step_size
        #     return occ
        super().__init__()
        self.occ_eval_fn = occ_eval_fn
        if not isinstance(aabb, torch.Tensor):
            aabb = torch.tensor(aabb, dtype=torch.float32)
        # A scalar resolution is broadcast to all dimensions.
        if not isinstance(resolution, (list, tuple)):
            resolution = [resolution] * num_dim
        assert aabb.shape == (
            num_dim * 2,
        ), f"shape of aabb ({aabb.shape}) should be num_dim * 2 ({num_dim * 2})."
        assert (
            len(resolution) == num_dim
        ), f"length of resolution ({len(resolution)}) should be num_dim ({num_dim})."
        self.register_buffer("aabb", aabb)
        self.resolution = resolution
        self.num_dim = num_dim
        self.num_cells = torch.tensor(resolution).prod().item()
        # Stores cell occupancy values ranged in [0, 1].
        occ_grid = torch.zeros(self.num_cells)
        self.register_buffer("occ_grid", occ_grid)
        # Binarized occupancy, recomputed from occ_grid on every update().
        occ_grid_binary = torch.zeros(self.num_cells, dtype=torch.bool)
        self.register_buffer("occ_grid_binary", occ_grid_binary)
        # Used for thresholding occ_grid.
        occ_grid_mean = occ_grid.mean()
        self.register_buffer("occ_grid_mean", occ_grid_mean)
        # Integer cell coordinates, flattened to (num_cells, num_dim).
        grid_coords = meshgrid(self.resolution).reshape(self.num_cells, self.num_dim)
        self.register_buffer("grid_coords", grid_coords)
        # Flat index of every cell, used by get_all_cells().
        grid_indices = torch.arange(self.num_cells)
        self.register_buffer("grid_indices", grid_indices)

    @torch.no_grad()
    def get_all_cells(
        self,
    ) -> torch.Tensor:
        """Returns the flat indices of all cells of the grid."""
        return self.grid_indices

    @torch.no_grad()
    def sample_uniform_and_occupied_cells(
        self, n: int
    ) -> torch.Tensor:
        """Samples both n uniform and (up to) n occupied cell indices.

        The result concatenates n uniformly random cells with a random
        subset of currently-occupied cells, so it may contain duplicates.
        """
        device = self.occ_grid.device
        uniform_indices = torch.randint(self.num_cells, (n,), device=device)
        occupied_indices = torch.nonzero(self.occ_grid_binary)[:, 0]
        # Subsample occupied cells only when there are more than n of them.
        if n < len(occupied_indices):
            selector = torch.randint(len(occupied_indices), (n,), device=device)
            occupied_indices = occupied_indices[selector]
        indices = torch.cat([uniform_indices, occupied_indices], dim=0)
        return indices

    @torch.no_grad()
    def update(
        self,
        step: int,
        occ_threshold: float = 0.01,
        ema_decay: float = 0.95,
    ) -> None:
        """Update the occ_grid (as well as occ_grid_binary) in EMA way.

        During warm-up (step < 256) every cell is re-evaluated; afterwards
        only a random subset of uniform + occupied cells is refreshed.
        """
        resolution = torch.tensor(self.resolution).to(self.occ_grid.device)
        # sample cells
        if step < 256:
            indices = self.get_all_cells()
        else:
            N = resolution.prod().item() // 4
            indices = self.sample_uniform_and_occupied_cells(N)
        # infer occupancy: density * step_size
        # -1 marks cells that were NOT evaluated this round (see ema_mask).
        tmp_occ_grid = -torch.ones_like(self.occ_grid)
        grid_coords = self.grid_coords[indices]
        # Jitter each sampled cell to a random position inside it, then map
        # from [0, 1]^d grid space into world space via the aabb.
        x = (grid_coords + torch.rand_like(grid_coords.float())) / resolution
        bb_min, bb_max = torch.split(self.aabb, [self.num_dim, self.num_dim], dim=0)
        x = x * (bb_max - bb_min) + bb_min
        tmp_occ_grid[indices] = self.occ_eval_fn(x).squeeze(-1)
        # ema update: only cells evaluated this round (tmp >= 0) are touched.
        ema_mask = (self.occ_grid >= 0) & (tmp_occ_grid >= 0)
        self.occ_grid[ema_mask] = torch.maximum(
            self.occ_grid[ema_mask] * ema_decay, tmp_occ_grid[ema_mask]
        )
        # Threshold at min(mean, occ_threshold) to binarize.
        self.occ_grid_mean = self.occ_grid.mean()
        self.occ_grid_binary = self.occ_grid > min(
            self.occ_grid_mean.item(), occ_threshold
        )

    @torch.no_grad()
    def query_occ(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Query the occ_grid at world-space points.

        Args:
            x: world-space coordinates with shape [..., num_dim].

        Returns:
            (occs, occs_binary): float and boolean occupancy per point,
            both shaped x.shape[:-1]. Points outside the aabb get 0 / False.
        """
        resolution = torch.tensor(self.resolution).to(self.occ_grid.device)
        bb_min, bb_max = torch.split(self.aabb, [self.num_dim, self.num_dim], dim=0)
        # Normalize into [0, 1]^d; points strictly inside the box are valid.
        x = (x - bb_min) / (bb_max - bb_min)
        selector = ((x > 0.0) & (x < 1.0)).all(dim=-1)
        grid_coords = torch.floor(x * resolution).long()
        # Row-major flattening of the integer cell coordinates.
        if self.num_dim == 2:
            grid_indices = (
                grid_coords[..., 0] * self.resolution[-1] + grid_coords[..., 1]
            )
        elif self.num_dim == 3:
            grid_indices = (
                grid_coords[..., 0] * self.resolution[-1] * self.resolution[-2]
                + grid_coords[..., 1] * self.resolution[-1]
                + grid_coords[..., 2]
            )
        occs = torch.zeros(x.shape[:-1], device=x.device)
        occs[selector] = self.occ_grid[grid_indices[selector]]
        occs_binary = torch.zeros(x.shape[:-1], device=x.device, dtype=bool)
        occs_binary[selector] = self.occ_grid_binary[grid_indices[selector]]
        return occs, occs_binary

    @torch.no_grad()
    def every_n_step(self, step: int, n: int = 16) -> None:
        """Call update() every n training steps (no-op in eval mode)."""
        if step % n == 0 and self.training:
            self.update(
                step=step,
                occ_threshold=0.01,
                ema_decay=0.95,
            )
import math
from typing import Callable, Tuple
import torch
from .cuda import VolumeRenderer, ray_aabb_intersect, ray_marching
def volumetric_rendering(
    query_fn: Callable,
    rays_o: torch.Tensor,
    rays_d: torch.Tensor,
    scene_aabb: torch.Tensor,
    scene_occ_binary: torch.Tensor,
    scene_resolution: Tuple[int, int, int],
    render_bkgd: torch.Tensor = None,
    render_n_samples: int = 1024,
    **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """A *fast* version of differentiable volumetric rendering.

    Marches rays through a binary occupancy grid (skipping empty cells),
    queries ``query_fn`` at the sample midpoints, and composites the result
    with a custom CUDA autograd renderer.

    Args:
        query_fn: Callable mapping (positions, directions, **kwargs) to a
            sequence whose first two entries are (rgbs, densities).
        rays_o: Ray origins. Presumably shape (n_rays, 3) — TODO confirm
            against the CUDA kernels.
        rays_d: Ray directions, same leading shape as ``rays_o``.
        scene_aabb: Flat (6,) tensor [min_x, min_y, min_z, max_x, max_y, max_z].
        scene_occ_binary: Flattened binary occupancy grid for empty-space skipping.
        scene_resolution: Occupancy-grid resolution along the three axes.
        render_bkgd: Background color (3,); defaults to white.
        render_n_samples: Maximum number of samples per ray.
        **kwargs: Forwarded verbatim to ``query_fn``.

    Returns:
        Tuple of four tensors: (accumulated_color, accumulated_depth,
        accumulated_weight, alive_ray_mask).
    """
    device = rays_o.device
    if render_bkgd is None:
        render_bkgd = torch.ones(3, device=device)
    scene_resolution = torch.tensor(scene_resolution, dtype=torch.int, device=device)
    # The CUDA kernels require contiguous memory.
    rays_o = rays_o.contiguous()
    rays_d = rays_d.contiguous()
    scene_aabb = scene_aabb.contiguous()
    scene_occ_binary = scene_occ_binary.contiguous()
    render_bkgd = render_bkgd.contiguous()
    n_rays = rays_o.shape[0]
    render_total_samples = n_rays * render_n_samples
    # Step size so that render_n_samples steps span the AABB diagonal.
    render_step_size = (
        (scene_aabb[3:] - scene_aabb[:3]).max() * math.sqrt(3) / render_n_samples
    )
    with torch.no_grad():
        # TODO: avoid clamp here. kinda stupid
        t_min, t_max = ray_aabb_intersect(rays_o, rays_d, scene_aabb)
        t_min = torch.clamp(t_min, max=1e10)
        t_max = torch.clamp(t_max, max=1e10)
        (
            packed_info,
            frustum_origins,
            frustum_dirs,
            frustum_starts,
            frustum_ends,
        ) = ray_marching(
            # rays
            rays_o,
            rays_d,
            t_min,
            t_max,
            # density grid
            scene_aabb,
            scene_resolution,
            scene_occ_binary,
            # sampling
            render_total_samples,
            render_n_samples,
            render_step_size,
        )
    # squeeze valid samples (keep at least one so slicing is non-empty)
    total_samples = max(packed_info[:, -1].sum(), 1)
    frustum_origins = frustum_origins[:total_samples]
    frustum_dirs = frustum_dirs[:total_samples]
    frustum_starts = frustum_starts[:total_samples]
    frustum_ends = frustum_ends[:total_samples]
    # Query the field at each frustum's midpoint.
    frustum_positions = (
        frustum_origins + frustum_dirs * (frustum_starts + frustum_ends) / 2.0
    )
    query_results = query_fn(frustum_positions, frustum_dirs, **kwargs)
    rgbs, densities = query_results[0], query_results[1]
    # Differentiable compositing via a custom autograd Function.
    (
        accumulated_weight,
        accumulated_depth,
        accumulated_color,
        alive_ray_mask,
    ) = VolumeRenderer.apply(
        packed_info,
        frustum_starts,
        frustum_ends,
        densities.contiguous(),
        rgbs.contiguous(),
    )
    accumulated_depth = torch.clip(accumulated_depth, t_min[:, None], t_max[:, None])
    # Alpha-composite the background behind the accumulated radiance.
    accumulated_color = accumulated_color + render_bkgd * (1.0 - accumulated_weight)
    return accumulated_color, accumulated_depth, accumulated_weight, alive_ray_mask
@@ setup.py (version bump 0.0.2 -> 0.0.3) @@
from setuptools import find_packages, setup

setup(
    name="nerfacc",
    description="NeRF accelerated rendering",
    version="0.0.3",  # was 0.0.2
    python_requires=">=3.9",
    packages=find_packages(exclude=("tests*",)),
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment