Commit f21fb818 authored by Ruilong Li's avatar Ruilong Li
Browse files

readme

parent 31422712
# nerfacc
This is a tiny tootlbox for fast NeRF rendering.
## Instant-NGP example
```
python examples/trainval.py
```
## Performance Reference
Tested with the default settings on the Lego dataset.
Here the speed refers to the `iterations per second`.
| Model | Split | PSNR | Train Speed | Test Speed | GPU |
| - | - | - | - | - | - |
| instant-ngp (paper) | trainval? | 36.39 | - | - | 3090 |
| torch-ngp (`-O`) | train (30K steps) | 34.15 | 97 | 7.8 | V100 |
| ours | train (30K steps) | 34.26 | 96 | ? | TITAN RTX |
\ No newline at end of file
...@@ -12,13 +12,12 @@ from radiance_fields.ngp import NGPradianceField ...@@ -12,13 +12,12 @@ from radiance_fields.ngp import NGPradianceField
from nerfacc import OccupancyField, volumetric_rendering from nerfacc import OccupancyField, volumetric_rendering
def render_image(radiance_field, rays, render_bkgd, chunk=8192): def render_image(radiance_field, rays, render_bkgd):
"""Render the pixels of an image. """Render the pixels of an image.
Args: Args:
radiance_field: the radiance field of nerf. radiance_field: the radiance field of nerf.
rays: a `Rays` namedtuple, the rays to be rendered. rays: a `Rays` namedtuple, the rays to be rendered.
chunk: int, the size of chunks to render sequentially.
Returns: Returns:
rgb: torch.tensor, rendered color image. rgb: torch.tensor, rendered color image.
...@@ -33,9 +32,10 @@ def render_image(radiance_field, rays, render_bkgd, chunk=8192): ...@@ -33,9 +32,10 @@ def render_image(radiance_field, rays, render_bkgd, chunk=8192):
else: else:
num_rays, _ = rays_shape num_rays, _ = rays_shape
results = [] results = []
chunk = torch.iinfo(torch.int32).max if radiance_field.training else 8192
for i in range(0, num_rays, chunk): for i in range(0, num_rays, chunk):
chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays) chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays)
chunk_color, chunk_depth, chunk_weight, _, = volumetric_rendering( chunk_color, chunk_depth, chunk_weight, alive_ray_mask, = volumetric_rendering(
query_fn=radiance_field.forward, # {x, dir} -> {rgb, density} query_fn=radiance_field.forward, # {x, dir} -> {rgb, density}
rays_o=chunk_rays.origins, rays_o=chunk_rays.origins,
rays_d=chunk_rays.viewdirs, rays_d=chunk_rays.viewdirs,
...@@ -45,12 +45,13 @@ def render_image(radiance_field, rays, render_bkgd, chunk=8192): ...@@ -45,12 +45,13 @@ def render_image(radiance_field, rays, render_bkgd, chunk=8192):
render_bkgd=render_bkgd, render_bkgd=render_bkgd,
render_n_samples=render_n_samples, render_n_samples=render_n_samples,
) )
results.append([chunk_color, chunk_depth, chunk_weight]) results.append([chunk_color, chunk_depth, chunk_weight, alive_ray_mask])
rgb, depth, acc = [torch.cat(r, dim=0) for r in zip(*results)] rgb, depth, acc, alive_ray_mask = [torch.cat(r, dim=0) for r in zip(*results)]
return ( return (
rgb.view((*rays_shape[:-1], -1)), rgb.view((*rays_shape[:-1], -1)),
depth.view((*rays_shape[:-1], -1)), depth.view((*rays_shape[:-1], -1)),
acc.view((*rays_shape[:-1], -1)), acc.view((*rays_shape[:-1], -1)),
alive_ray_mask.view(*rays_shape[:-1]),
) )
...@@ -136,7 +137,9 @@ if __name__ == "__main__": ...@@ -136,7 +137,9 @@ if __name__ == "__main__":
# update occupancy grid # update occupancy grid
occ_field.every_n_step(step) occ_field.every_n_step(step)
rgb, depth, acc = render_image(radiance_field, rays, render_bkgd) rgb, depth, acc, alive_ray_mask = render_image(
radiance_field, rays, render_bkgd
)
# compute loss # compute loss
loss = F.mse_loss(rgb, pixels) loss = F.mse_loss(rgb, pixels)
...@@ -162,7 +165,7 @@ if __name__ == "__main__": ...@@ -162,7 +165,7 @@ if __name__ == "__main__":
pixels = data["pixels"].to(device) pixels = data["pixels"].to(device)
render_bkgd = data["color_bkgd"].to(device) render_bkgd = data["color_bkgd"].to(device)
# rendering # rendering
rgb, depth, acc = render_image( rgb, depth, acc, alive_ray_mask = render_image(
radiance_field, rays, render_bkgd radiance_field, rays, render_bkgd
) )
mse = F.mse_loss(rgb, pixels) mse = F.mse_loss(rgb, pixels)
......
...@@ -15,6 +15,7 @@ def volumetric_rendering( ...@@ -15,6 +15,7 @@ def volumetric_rendering(
scene_resolution: Tuple[int, int, int], scene_resolution: Tuple[int, int, int],
render_bkgd: torch.Tensor = None, render_bkgd: torch.Tensor = None,
render_n_samples: int = 1024, render_n_samples: int = 1024,
render_est_n_samples: int = None,
**kwargs, **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""A *fast* version of differentiable volumetric rendering.""" """A *fast* version of differentiable volumetric rendering."""
...@@ -30,7 +31,10 @@ def volumetric_rendering( ...@@ -30,7 +31,10 @@ def volumetric_rendering(
render_bkgd = render_bkgd.contiguous() render_bkgd = render_bkgd.contiguous()
n_rays = rays_o.shape[0] n_rays = rays_o.shape[0]
render_total_samples = n_rays * render_n_samples if render_est_n_samples is None:
render_total_samples = n_rays * render_n_samples
else:
render_total_samples = n_rays * render_est_n_samples
render_step_size = ( render_step_size = (
(scene_aabb[3:] - scene_aabb[:3]).max() * math.sqrt(3) / render_n_samples (scene_aabb[3:] - scene_aabb[:3]).max() * math.sqrt(3) / render_n_samples
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment