Unverified Commit ed327e36 authored by Ruilong Li(李瑞龙)'s avatar Ruilong Li(李瑞龙) Committed by GitHub
Browse files

Ruilongli/docs (#6)

* build doc should work

* prettier doc

* Tensor
parent 00cbf55e
...@@ -33,22 +33,21 @@ def meshgrid3d(res: Tuple[int, int, int], device: torch.device = "cpu"): ...@@ -33,22 +33,21 @@ def meshgrid3d(res: Tuple[int, int, int], device: torch.device = "cpu"):
class OccupancyField(nn.Module): class OccupancyField(nn.Module):
"""Occupancy Field that supports EMA updates. """Occupancy Field that supports EMA updates. Both 2D and 3D are supported.
It supports both 2D and 3D cases, where in the 2D cases the occupancy field Note:
is basically a segmentation mask. Make sure the arguments match with the ``num_dim`` -- Either 2D or 3D.
Args: Args:
occ_eval_fn: A Callable function that takes in the un-normalized points x, occ_eval_fn: A Callable function that takes in the un-normalized points x,
with shape of (N, 2) or (N, 3) (depends on `num_dim`), and outputs with shape of (N, 2) or (N, 3) (depends on ``num_dim``),
the occupancy of those points with shape of (N, 1). and outputs the occupancy of those points with shape of (N, 1).
aabb: Scene bounding box. {min_x, min_y, (min_z), max_x, max_y, (max_z)}. aabb: Scene bounding box. If ``num_dim=2`` it should be {min_x, min_y,max_x, max_y}.
It can be either a list or a torch.Tensor. If ``num_dim=3`` it should be {min_x, min_y, min_z, max_x, max_y, max_z}.
resolution: The field resolution. It can either be a int of a list of ints resolution: The field resolution. It can either be a int of a list of ints
to specify resolution on each dimension. {res_x, res_y, (res_z)}. Default to specify resolution on each dimension. If ``num_dim=2`` it is for {res_x, res_y}.
is 128. If ``num_dim=3`` it is for {res_x, res_y, res_z}. Default is 128.
num_dim: The space dimension. Either 2 or 3. Default is 3. Note other arguments num_dim: The space dimension. Either 2 or 3. Default is 3.
should match with the space dimension being set here.
""" """
def __init__( def __init__(
...@@ -114,23 +113,14 @@ class OccupancyField(nn.Module): ...@@ -114,23 +113,14 @@ class OccupancyField(nn.Module):
return indices return indices
@torch.no_grad() @torch.no_grad()
def update( def _update(
self, self,
step: int, step: int,
occ_threshold: float = 0.01, occ_thre: float = 0.01,
ema_decay: float = 0.95, ema_decay: float = 0.95,
warmup_steps: int = 256, warmup_steps: int = 256,
) -> None: ) -> None:
"""Update the occ field in the EMA way. """Update the occ field in the EMA way."""
Args:
step: Current training step.
occ_threshold: Threshold to binarize the occupancy field.
ema_decay: The decay rate for EMA updates.
warmup_steps: Sample all cells during the warmup stage. After the warmup
stage we change the sampling strategy to 1/4 uniformly sampled cells
together with 1/4 occupied cells.
"""
# sample cells # sample cells
if step < warmup_steps: if step < warmup_steps:
indices = self._get_all_cells() indices = self._get_all_cells()
...@@ -157,7 +147,7 @@ class OccupancyField(nn.Module): ...@@ -157,7 +147,7 @@ class OccupancyField(nn.Module):
) )
self.occ_grid_mean = self.occ_grid.mean() self.occ_grid_mean = self.occ_grid.mean()
self.occ_grid_binary = self.occ_grid > torch.clamp( self.occ_grid_binary = self.occ_grid > torch.clamp(
self.occ_grid_mean, max=occ_threshold self.occ_grid_mean, max=occ_thre
) )
@torch.no_grad() @torch.no_grad()
...@@ -168,8 +158,7 @@ class OccupancyField(nn.Module): ...@@ -168,8 +158,7 @@ class OccupancyField(nn.Module):
x: Samples with shape (..., 2) or (..., 3). x: Samples with shape (..., 2) or (..., 3).
Returns: Returns:
float occupancy values with shape (...), float and binary occupancy values with shape (...) respectively.
binary occupancy values with shape (...)
""" """
assert ( assert (
x.shape[-1] == self.num_dim x.shape[-1] == self.num_dim
...@@ -206,11 +195,27 @@ class OccupancyField(nn.Module): ...@@ -206,11 +195,27 @@ class OccupancyField(nn.Module):
warmup_steps: int = 256, warmup_steps: int = 256,
n: int = 16, n: int = 16,
): ):
"""Update the field every n steps during training.""" """Update the field every n steps during training.
This function is designed for training only. If for some reason you want to
manually update the field, please use the ``_update()`` function instead.
Args:
step: Current training step.
occ_thre: Threshold to binarize the occupancy field.
ema_decay: The decay rate for EMA updates.
warmup_steps: Sample all cells during the warmup stage. After the warmup
stage we change the sampling strategy to 1/4 uniformly sampled cells
together with 1/4 occupied cells.
n: Update the field every n steps.
Returns:
None
"""
if not self.training: if not self.training:
raise RuntimeError( raise RuntimeError(
"You should only call this function only during training. " "You should only call this function only during training. "
"Please call update() directly if you want to update the " "Please call _update() directly if you want to update the "
"field during inference." "field during inference."
) )
if step % n == 0 and self.training: if step % n == 0 and self.training:
......
from typing import Tuple from typing import Tuple
import torch import torch
from torch import Tensor
import nerfacc.cuda as nerfacc_cuda import nerfacc.cuda as nerfacc_cuda
@torch.no_grad() @torch.no_grad()
def ray_aabb_intersect( def ray_aabb_intersect(
rays_o: torch.Tensor, rays_d: torch.Tensor, aabb: torch.Tensor rays_o: Tensor, rays_d: Tensor, aabb: Tensor
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[Tensor, Tensor]:
"""Ray AABB Test. """Ray AABB Test.
Note: this function is not differentiable to inputs. Note: this function is not differentiable to inputs.
...@@ -16,11 +17,11 @@ def ray_aabb_intersect( ...@@ -16,11 +17,11 @@ def ray_aabb_intersect(
Args: Args:
rays_o: Ray origins. Tensor with shape (n_rays, 3). rays_o: Ray origins. Tensor with shape (n_rays, 3).
rays_d: Normalized ray directions. Tensor with shape (n_rays, 3). rays_d: Normalized ray directions. Tensor with shape (n_rays, 3).
aabb: Scene bounding box {xmin, ymin, zmin, xmax, ymax, zmax}. aabb: Scene bounding box {xmin, ymin, zmin, xmax, ymax, zmax}. \
Tensor with shape (6) Tensor with shape (6)
Returns: Returns:
Ray AABB intersection {t_min, t_max} with shape (n_rays) respectively. Ray AABB intersection {t_min, t_max} with shape (n_rays) respectively. \
Note the t_min is clipped to minimum zero. 1e10 means no intersection. Note the t_min is clipped to minimum zero. 1e10 means no intersection.
""" """
...@@ -36,15 +37,15 @@ def ray_aabb_intersect( ...@@ -36,15 +37,15 @@ def ray_aabb_intersect(
@torch.no_grad() @torch.no_grad()
def volumetric_marching( def volumetric_marching(
rays_o: torch.Tensor, rays_o: Tensor,
rays_d: torch.Tensor, rays_d: Tensor,
aabb: torch.Tensor, aabb: Tensor,
scene_resolution: Tuple[int, int, int], scene_resolution: Tuple[int, int, int],
scene_occ_binary: torch.Tensor, scene_occ_binary: Tensor,
t_min: torch.Tensor = None, t_min: Tensor = None,
t_max: torch.Tensor = None, t_max: Tensor = None,
render_step_size: float = 1e-3, render_step_size: float = 1e-3,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
"""Volumetric marching with occupancy test. """Volumetric marching with occupancy test.
Note: this function is not differentiable to inputs. Note: this function is not differentiable to inputs.
...@@ -52,26 +53,28 @@ def volumetric_marching( ...@@ -52,26 +53,28 @@ def volumetric_marching(
Args: Args:
rays_o: Ray origins. Tensor with shape (n_rays, 3). rays_o: Ray origins. Tensor with shape (n_rays, 3).
rays_d: Normalized ray directions. Tensor with shape (n_rays, 3). rays_d: Normalized ray directions. Tensor with shape (n_rays, 3).
aabb: Scene bounding box {xmin, ymin, zmin, xmax, ymax, zmax}. aabb: Scene bounding box {xmin, ymin, zmin, xmax, ymax, zmax}. \
Tensor with shape (6) Tensor with shape (6)
scene_resolution: Shape of the `scene_occ_binary`. {resx, resy, resz}. scene_resolution: Shape of the `scene_occ_binary`. {resx, resy, resz}.
scene_occ_binary: Scene occupancy binary field. BoolTensor with shape scene_occ_binary: Scene occupancy binary field. BoolTensor with \
(resx * resy * resz) shape (resx * resy * resz)
t_min: Optional. Ray near planes. Tensor with shape (n_ray,). t_min: Optional. Ray near planes. Tensor with shape (n_ray,). \
If not given it will be calculated using aabb test. Default is None. If not given it will be calculated using aabb test. Default is None.
t_max: Optional. Ray far planes. Tensor with shape (n_ray,) t_max: Optional. Ray far planes. Tensor with shape (n_ray,). \
If not given it will be calculated using aabb test. Default is None. If not given it will be calculated using aabb test. Default is None.
render_step_size: Marching step size. Default is 1e-3. render_step_size: Marching step size. Default is 1e-3.
Returns: Returns:
packed_info: Stores information on which samples belong to the same ray. A tuple of tensors containing
It is a tensor with shape (n_rays, 2). For each ray, the two values
indicate the start index and the number of samples for this ray, - **packed_info**: Stores information on which samples belong to the same ray. \
respectively. It is a tensor with shape (n_rays, 2). For each ray, the two values \
frustum_origins: Sampled frustum origins. Tensor with shape (n_samples, 3). indicate the start index and the number of samples for this ray, \
frustum_dirs: Sampled frustum directions. Tensor with shape (n_samples, 3). respectively.
frustum_starts: Sampled frustum starts. Tensor with shape (n_samples, 1). - **frustum_origins**: Sampled frustum origins. Tensor with shape (n_samples, 3).
frustum_ends: Sampled frustum ends. Tensor with shape (n_samples, 1). - **frustum_dirs**: Sampled frustum directions. Tensor with shape (n_samples, 3).
- **frustum_starts**: Sampled frustum starts. Tensor with shape (n_samples, 1).
- **frustum_ends**: Sampled frustum ends. Tensor with shape (n_samples, 1).
""" """
if not rays_o.is_cuda: if not rays_o.is_cuda:
...@@ -114,34 +117,36 @@ def volumetric_marching( ...@@ -114,34 +117,36 @@ def volumetric_marching(
@torch.no_grad() @torch.no_grad()
def volumetric_rendering_steps( def volumetric_rendering_steps(
packed_info: torch.Tensor, packed_info: Tensor,
sigmas: torch.Tensor, sigmas: Tensor,
frustum_starts: torch.Tensor, frustum_starts: Tensor,
frustum_ends: torch.Tensor, frustum_ends: Tensor,
*args, *args,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple[Tensor, Tensor, Tensor]:
"""Compute rendering marching steps. """Compute rendering marching steps.
This function will compact the samples by terminate the marching once the This function will compact the samples by terminate the marching once the \
transmittance reaches to 0.9999. It is recommended that before running your transmittance reaches to 0.9999. It is recommended that before running your \
network with gradients enabled, first run this function without gradients network with gradients enabled, first run this function without gradients \
(torch.no_grad()) to quickly filter out some samples. (torch.no_grad()) to quickly filter out some samples.
Note: this function is not differentiable to inputs. Note: this function is not differentiable to inputs.
Args: Args:
packed_info: Stores information on which samples belong to the same ray. packed_info: Stores information on which samples belong to the same ray. \
See volumetric_marching for details. Tensor with shape (n_rays, 2). See volumetric_marching for details. Tensor with shape (n_rays, 2). \
sigmas: Densities at those samples. Tensor with shape (n_samples, 1). sigmas: Densities at those samples. Tensor with shape (n_samples, 1).
frustum_starts: Where the frustum-shape sample starts along a ray. Tensor with frustum_starts: Where the frustum-shape sample starts along a ray. Tensor with \
shape (n_samples, 1). shape (n_samples, 1).
frustum_ends: Where the frustum-shape sample ends along a ray. Tensor with frustum_ends: Where the frustum-shape sample ends along a ray. Tensor with \
shape (n_samples, 1). shape (n_samples, 1).
Returns: Returns:
compact_packed_info: Compacted version of input packed_info. A tuple of tensors containing
compact_frustum_starts: Compacted version of input frustum_starts.
compact_frustum_ends: Compacted version of input frustum_ends. **compact_packed_info**: Compacted version of input packed_info.
**compact_frustum_starts**: Compacted version of input frustum_starts.
**compact_frustum_ends**: Compacted version of input frustum_ends.
""" """
if ( if (
...@@ -171,28 +176,29 @@ def volumetric_rendering_steps( ...@@ -171,28 +176,29 @@ def volumetric_rendering_steps(
def volumetric_rendering_weights( def volumetric_rendering_weights(
packed_info: torch.Tensor, packed_info: Tensor,
sigmas: torch.Tensor, sigmas: Tensor,
frustum_starts: torch.Tensor, frustum_starts: Tensor,
frustum_ends: torch.Tensor, frustum_ends: Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple[Tensor, Tensor, Tensor]:
"""Compute weights for volumetric rendering. """Compute weights for volumetric rendering.
Note: this function is only differentiable to `sigmas`. Note: this function is only differentiable to `sigmas`.
Args: Args:
packed_info: Stores information on which samples belong to the same ray. packed_info: Stores information on which samples belong to the same ray. \
See volumetric_marching for details. Tensor with shape (n_rays, 2). See ``volumetric_marching`` for details. Tensor with shape (n_rays, 2).
sigmas: Densities at those samples. Tensor with shape (n_samples, 1). sigmas: Densities at those samples. Tensor with shape (n_samples, 1).
frustum_starts: Where the frustum-shape sample starts along a ray. Tensor with frustum_starts: Where the frustum-shape sample starts along a ray. Tensor with \
shape (n_samples, 1). shape (n_samples, 1).
frustum_ends: Where the frustum-shape sample ends along a ray. Tensor with frustum_ends: Where the frustum-shape sample ends along a ray. Tensor with \
shape (n_samples, 1). shape (n_samples, 1).
Returns: Returns:
weights: Volumetric rendering weights for those samples. Tensor with shape A tuple of tensors containing
(n_samples).
ray_indices: Ray index of each sample. IntTensor with shape (n_sample). - **weights**: Volumetric rendering weights for those samples. Tensor with shape (n_samples).
- **ray_indices**: Ray index of each sample. IntTensor with shape (n_sample).
""" """
if ( if (
...@@ -214,28 +220,28 @@ def volumetric_rendering_weights( ...@@ -214,28 +220,28 @@ def volumetric_rendering_weights(
def volumetric_rendering_accumulate( def volumetric_rendering_accumulate(
weights: torch.Tensor, weights: Tensor,
ray_indices: torch.Tensor, ray_indices: Tensor,
values: torch.Tensor = None, values: Tensor = None,
n_rays: int = None, n_rays: int = None,
) -> torch.Tensor: ) -> Tensor:
"""Accumulate volumetric values along the ray. """Accumulate volumetric values along the ray.
Note: this function is only differentiable to weights and values. Note: this function is only differentiable to weights and values.
Args: Args:
weights: Volumetric rendering weights for those samples. Tensor with shape weights: Volumetric rendering weights for those samples. Tensor with shape \
(n_samples). (n_samples).
ray_indices: Ray index of each sample. IntTensor with shape (n_sample). ray_indices: Ray index of each sample. IntTensor with shape (n_sample).
values: The values to be accumulated. Tensor with shape (n_samples, D). If values: The values to be accumulated. Tensor with shape (n_samples, D). If \
None, the accumulated values are just weights. Default is None. None, the accumulated values are just weights. Default is None.
n_rays: Total number of rays. This will decide the shape of the outputs. If n_rays: Total number of rays. This will decide the shape of the outputs. If \
None, it will be inferred from `ray_indices.max() + 1`. If specified None, it will be inferred from `ray_indices.max() + 1`. If specified \
it should be at least larger than `ray_indices.max()`. Default is None. it should be at least larger than `ray_indices.max()`. Default is None.
Returns: Returns:
Accumulated values with shape (n_rays, D). If `values` is not given then Accumulated values with shape (n_rays, D). If `values` is not given then we return \
we return the accumulated weights, in which case D == 1. the accumulated weights, in which case D == 1.
""" """
assert ray_indices.dim() == 1 and weights.dim() == 1 assert ray_indices.dim() == 1 and weights.dim() == 1
if not weights.is_cuda: if not weights.is_cuda:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment