Commit ba049483 authored by Ruilong Li's avatar Ruilong Li
Browse files

a bit cleaning

parent 5b6f0c61
...@@ -55,9 +55,6 @@ def render_image(radiance_field, rays, render_bkgd, render_step_size): ...@@ -55,9 +55,6 @@ def render_image(radiance_field, rays, render_bkgd, render_step_size):
num_rays, _ = rays_shape num_rays, _ = rays_shape
results = [] results = []
chunk = torch.iinfo(torch.int32).max if radiance_field.training else 81920 chunk = torch.iinfo(torch.int32).max if radiance_field.training else 81920
render_est_n_samples = (
TARGET_SAMPLE_BATCH_SIZE * 16 if radiance_field.training else None
)
for i in range(0, num_rays, chunk): for i in range(0, num_rays, chunk):
chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays) chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays)
chunk_results = volumetric_rendering( chunk_results = volumetric_rendering(
...@@ -68,8 +65,6 @@ def render_image(radiance_field, rays, render_bkgd, render_step_size): ...@@ -68,8 +65,6 @@ def render_image(radiance_field, rays, render_bkgd, render_step_size):
scene_occ_binary=occ_field.occ_grid_binary, scene_occ_binary=occ_field.occ_grid_binary,
scene_resolution=occ_field.resolution, scene_resolution=occ_field.resolution,
render_bkgd=render_bkgd, render_bkgd=render_bkgd,
render_n_samples=render_n_samples,
render_est_n_samples=render_est_n_samples, # memory control: wrost case
render_step_size=render_step_size, render_step_size=render_step_size,
) )
results.append(chunk_results) results.append(chunk_results)
......
...@@ -38,6 +38,7 @@ def volumetric_marching( ...@@ -38,6 +38,7 @@ def volumetric_marching(
rays_o: torch.Tensor, rays_o: torch.Tensor,
rays_d: torch.Tensor, rays_d: torch.Tensor,
aabb: torch.Tensor, aabb: torch.Tensor,
scene_resolution: Tuple[int, int, int],
scene_occ_binary: torch.Tensor, scene_occ_binary: torch.Tensor,
t_min: torch.Tensor = None, t_min: torch.Tensor = None,
t_max: torch.Tensor = None, t_max: torch.Tensor = None,
...@@ -52,8 +53,9 @@ def volumetric_marching( ...@@ -52,8 +53,9 @@ def volumetric_marching(
rays_d: Normalized ray directions. Tensor with shape (n_rays, 3). rays_d: Normalized ray directions. Tensor with shape (n_rays, 3).
aabb: Scene bounding box {xmin, ymin, zmin, xmax, ymax, zmax}. aabb: Scene bounding box {xmin, ymin, zmin, xmax, ymax, zmax}.
Tensor with shape (6) Tensor with shape (6)
scene_resolution: Shape of the `scene_occ_binary`. {resx, resy, resz}.
scene_occ_binary: Scene occupancy binary field. BoolTensor with shape scene_occ_binary: Scene occupancy binary field. BoolTensor with shape
(resx, resy, resz) (resx * resy * resz)
t_min: Optional. Ray near planes. Tensor with shape (n_ray,). t_min: Optional. Ray near planes. Tensor with shape (n_ray,).
If not given it will be calculated using aabb test. Default is None. If not given it will be calculated using aabb test. Default is None.
t_max: Optional. Ray far planes. Tensor with shape (n_ray,) t_max: Optional. Ray far planes. Tensor with shape (n_ray,)
...@@ -74,7 +76,10 @@ def volumetric_marching( ...@@ -74,7 +76,10 @@ def volumetric_marching(
raise NotImplementedError("Only support cuda inputs.") raise NotImplementedError("Only support cuda inputs.")
if t_min is None or t_max is None: if t_min is None or t_max is None:
t_min, t_max = ray_aabb_intersect(rays_o, rays_d, aabb) t_min, t_max = ray_aabb_intersect(rays_o, rays_d, aabb)
assert scene_occ_binary.dim() == 3, f"Shape {scene_occ_binary.shape} is not right!" assert (
scene_occ_binary.numel()
== scene_resolution[0] * scene_resolution[1] * scene_resolution[2]
), f"Shape {scene_occ_binary.shape} is not right!"
( (
packed_info, packed_info,
...@@ -90,7 +95,7 @@ def volumetric_marching( ...@@ -90,7 +95,7 @@ def volumetric_marching(
t_max.contiguous(), t_max.contiguous(),
# density grid # density grid
aabb.contiguous(), aabb.contiguous(),
list(scene_occ_binary.shape), scene_resolution,
scene_occ_binary.contiguous(), scene_occ_binary.contiguous(),
# sampling # sampling
render_step_size, render_step_size,
......
import math
from typing import Callable, Tuple from typing import Callable, Tuple
import torch import torch
...@@ -18,29 +17,11 @@ def volumetric_rendering( ...@@ -18,29 +17,11 @@ def volumetric_rendering(
scene_aabb: torch.Tensor, scene_aabb: torch.Tensor,
scene_occ_binary: torch.Tensor, scene_occ_binary: torch.Tensor,
scene_resolution: Tuple[int, int, int], scene_resolution: Tuple[int, int, int],
render_bkgd: torch.Tensor = None, render_bkgd: torch.Tensor,
render_n_samples: int = 1024, render_step_size: int,
render_est_n_samples: int = None,
render_step_size: int = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""A *fast* version of differentiable volumetric rendering.""" """A *fast* version of differentiable volumetric rendering."""
device = rays_o.device
if render_bkgd is None:
render_bkgd = torch.ones(3, device=device)
rays_o = rays_o.contiguous()
rays_d = rays_d.contiguous()
scene_aabb = scene_aabb.contiguous()
scene_occ_binary = scene_occ_binary.contiguous()
render_bkgd = render_bkgd.contiguous()
n_rays = rays_o.shape[0] n_rays = rays_o.shape[0]
if render_step_size is None:
# Note: CPU<->GPU is not idea, try to pre-define it outside this function.
render_step_size = (
(scene_aabb[3:] - scene_aabb[:3]).max() * math.sqrt(3) / render_n_samples
)
# get packed samples from ray marching & occupancy check. # get packed samples from ray marching & occupancy check.
with torch.no_grad(): with torch.no_grad():
...@@ -56,7 +37,8 @@ def volumetric_rendering( ...@@ -56,7 +37,8 @@ def volumetric_rendering(
rays_d, rays_d,
# density grid # density grid
aabb=scene_aabb, aabb=scene_aabb,
scene_occ_binary=scene_occ_binary.reshape(scene_resolution), scene_resolution=scene_resolution,
scene_occ_binary=scene_occ_binary,
# sampling # sampling
render_step_size=render_step_size, render_step_size=render_step_size,
) )
...@@ -67,9 +49,7 @@ def volumetric_rendering( ...@@ -67,9 +49,7 @@ def volumetric_rendering(
# compat the samples thru volumetric rendering # compat the samples thru volumetric rendering
with torch.no_grad(): with torch.no_grad():
densities = query_fn( densities = query_fn(frustum_positions, frustum_dirs, only_density=True)
frustum_positions, frustum_dirs, only_density=True, **kwargs
)
( (
compact_packed_info, compact_packed_info,
compact_frustum_starts, compact_frustum_starts,
...@@ -84,18 +64,10 @@ def volumetric_rendering( ...@@ -84,18 +64,10 @@ def volumetric_rendering(
frustum_positions, frustum_positions,
frustum_dirs, frustum_dirs,
) )
# compact_frustum_positions = (
# compact_frustum_origins
# + compact_frustum_dirs
# * (compact_frustum_starts + compact_frustum_ends)
# / 2.0
# )
compact_steps_counter = compact_packed_info[:, -1].sum(0, keepdim=True) compact_steps_counter = compact_packed_info[:, -1].sum(0, keepdim=True)
# network # network
compact_query_results = query_fn( compact_query_results = query_fn(compact_frustum_positions, compact_frustum_dirs)
compact_frustum_positions, compact_frustum_dirs, **kwargs
)
compact_rgbs, compact_densities = compact_query_results[0], compact_query_results[1] compact_rgbs, compact_densities = compact_query_results[0], compact_query_results[1]
# accumulation # accumulation
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment