Unverified Commit 19c7b89f authored by Ruilong Li(李瑞龙)'s avatar Ruilong Li(李瑞龙) Committed by GitHub
Browse files

support stratified sampling (#10)

* support stratified sampling

* bump version
parent 7f9ecf67
...@@ -15,7 +15,7 @@ Ours on TITAN RTX : ...@@ -15,7 +15,7 @@ Ours on TITAN RTX :
| trainval | Lego | Mic | Materials | | trainval | Lego | Mic | Materials |
| - | - | - | - | | - | - | - | - |
| Time | 300s | 272s | 258s | | Time | 300s | 272s | 258s |
| PSNR | 36.28 | 36.16 | 29.76 | | PSNR | 36.61 | 37.45 | 30.15 |
| FPS | 11.49 | 21.48 | 8.86 | | FPS | 11.49 | 21.48 | 8.86 |
Instant-NGP paper (5 min) on 3090: Instant-NGP paper (5 min) on 3090:
......
...@@ -46,6 +46,7 @@ def render_image(radiance_field, rays, render_bkgd, render_step_size): ...@@ -46,6 +46,7 @@ def render_image(radiance_field, rays, render_bkgd, render_step_size):
scene_resolution=occ_field.resolution, scene_resolution=occ_field.resolution,
render_bkgd=render_bkgd, render_bkgd=render_bkgd,
render_step_size=render_step_size, render_step_size=render_step_size,
stratified=radiance_field.training,
) )
results.append(chunk_results) results.append(chunk_results)
rgb, depth, acc, counter, compact_counter = [ rgb, depth, acc, counter, compact_counter = [
...@@ -65,7 +66,7 @@ if __name__ == "__main__": ...@@ -65,7 +66,7 @@ if __name__ == "__main__":
torch.manual_seed(42) torch.manual_seed(42)
device = "cuda:0" device = "cuda:0"
scene = "lego" scene = "materials"
# setup dataset # setup dataset
train_dataset = SubjectLoader( train_dataset = SubjectLoader(
......
...@@ -45,7 +45,8 @@ def volumetric_marching( ...@@ -45,7 +45,8 @@ def volumetric_marching(
t_min: Tensor = None, t_min: Tensor = None,
t_max: Tensor = None, t_max: Tensor = None,
render_step_size: float = 1e-3, render_step_size: float = 1e-3,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: stratified: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Volumetric marching with occupancy test. """Volumetric marching with occupancy test.
Note: this function is not differentiable to inputs. Note: this function is not differentiable to inputs.
...@@ -63,6 +64,7 @@ def volumetric_marching( ...@@ -63,6 +64,7 @@ def volumetric_marching(
t_max: Optional. Ray far planes. Tensor with shape (n_ray,). \ t_max: Optional. Ray far planes. Tensor with shape (n_ray,). \
If not given it will be calculated using aabb test. Default is None. If not given it will be calculated using aabb test. Default is None.
render_step_size: Marching step size. Default is 1e-3. render_step_size: Marching step size. Default is 1e-3.
stratified: Whether to use stratified sampling. Default is False.
Returns: Returns:
A tuple of tensors containing A tuple of tensors containing
...@@ -86,6 +88,8 @@ def volumetric_marching( ...@@ -86,6 +88,8 @@ def volumetric_marching(
== scene_resolution[0] * scene_resolution[1] * scene_resolution[2] == scene_resolution[0] * scene_resolution[1] * scene_resolution[2]
), f"Shape {scene_occ_binary.shape} is not right!" ), f"Shape {scene_occ_binary.shape} is not right!"
if stratified:
t_min = t_min + torch.rand_like(t_min) * render_step_size
( (
packed_info, packed_info,
frustum_origins, frustum_origins,
......
...@@ -19,6 +19,7 @@ def volumetric_rendering( ...@@ -19,6 +19,7 @@ def volumetric_rendering(
scene_resolution: Tuple[int, int, int], scene_resolution: Tuple[int, int, int],
render_bkgd: torch.Tensor, render_bkgd: torch.Tensor,
render_step_size: int, render_step_size: int,
stratified: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""A *fast* version of differentiable volumetric rendering.""" """A *fast* version of differentiable volumetric rendering."""
n_rays = rays_o.shape[0] n_rays = rays_o.shape[0]
...@@ -47,6 +48,7 @@ def volumetric_rendering( ...@@ -47,6 +48,7 @@ def volumetric_rendering(
scene_occ_binary=scene_occ_binary, scene_occ_binary=scene_occ_binary,
# sampling # sampling
render_step_size=render_step_size, render_step_size=render_step_size,
stratified=stratified,
) )
frustum_positions = ( frustum_positions = (
frustum_origins + frustum_dirs * (frustum_starts + frustum_ends) / 2.0 frustum_origins + frustum_dirs * (frustum_starts + frustum_ends) / 2.0
......
...@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" ...@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "nerfacc" name = "nerfacc"
version = "0.0.3" version = "0.0.4"
authors = [{name = "Ruilong", email = "ruilongli94@gmail.com"}] authors = [{name = "Ruilong", email = "ruilongli94@gmail.com"}]
license = { text="MIT" } license = { text="MIT" }
requires-python = ">=3.8" requires-python = ">=3.8"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment