Unverified Commit ff7cf01a authored by Ruilong Li(李瑞龙)'s avatar Ruilong Li(李瑞龙) Committed by GitHub
Browse files

update docs for ray gradients (#62)

parent f9a6ae6a
...@@ -39,9 +39,9 @@ from torch import Tensor ...@@ -39,9 +39,9 @@ from torch import Tensor
import nerfacc import nerfacc
radiance_field = ... # network: a NeRF model radiance_field = ... # network: a NeRF model
optimizer = ... # network optimizer
rays_o: Tensor = ... # ray origins. (n_rays, 3) rays_o: Tensor = ... # ray origins. (n_rays, 3)
rays_d: Tensor = ... # ray normalized directions. (n_rays, 3) rays_d: Tensor = ... # ray normalized directions. (n_rays, 3)
optimizer = ... # optimizer
def sigma_fn( def sigma_fn(
t_starts: Tensor, t_ends:Tensor, ray_indices: Tensor t_starts: Tensor, t_ends:Tensor, ray_indices: Tensor
...@@ -76,16 +76,17 @@ def rgb_sigma_fn( ...@@ -76,16 +76,17 @@ def rgb_sigma_fn(
# Efficient Raymarching: Skip empty and occluded space, pack samples from all rays. # Efficient Raymarching: Skip empty and occluded space, pack samples from all rays.
# packed_info: (n_rays, 2). t_starts: (n_samples, 1). t_ends: (n_samples, 1). # packed_info: (n_rays, 2). t_starts: (n_samples, 1). t_ends: (n_samples, 1).
packed_info, t_starts, t_ends = nerfacc.ray_marching( with torch.no_grad():
rays_o, rays_d, sigma_fn=sigma_fn, near_plane=0.2, far_plane=1.0, packed_info, t_starts, t_ends = nerfacc.ray_marching(
early_stop_eps=1e-4, alpha_thre=1e-2, rays_o, rays_d, sigma_fn=sigma_fn, near_plane=0.2, far_plane=1.0,
) early_stop_eps=1e-4, alpha_thre=1e-2,
)
# Differentiable Volumetric Rendering. # Differentiable Volumetric Rendering.
# colors: (n_rays, 3). opaicity: (n_rays, 1). depth: (n_rays, 1). # colors: (n_rays, 3). opaicity: (n_rays, 1). depth: (n_rays, 1).
color, opacity, depth = nerfacc.rendering(rgb_sigma_fn, packed_info, t_starts, t_ends) color, opacity, depth = nerfacc.rendering(rgb_sigma_fn, packed_info, t_starts, t_ends)
# Optimize the radience field. # Optimize: Both the network and rays will receive gradients
optimizer.zero_grad() optimizer.zero_grad()
loss = F.mse_loss(color, color_gt) loss = F.mse_loss(color, color_gt)
loss.backward() loss.backward()
......
...@@ -53,9 +53,9 @@ An simple example is like this: ...@@ -53,9 +53,9 @@ An simple example is like this:
import nerfacc import nerfacc
radiance_field = ... # network: a NeRF model radiance_field = ... # network: a NeRF model
optimizer = ... # network optimizer
rays_o: Tensor = ... # ray origins. (n_rays, 3) rays_o: Tensor = ... # ray origins. (n_rays, 3)
rays_d: Tensor = ... # ray normalized directions. (n_rays, 3) rays_d: Tensor = ... # ray normalized directions. (n_rays, 3)
optimizer = ... # optimizer
def sigma_fn( def sigma_fn(
t_starts: Tensor, t_ends:Tensor, ray_indices: Tensor t_starts: Tensor, t_ends:Tensor, ray_indices: Tensor
...@@ -90,16 +90,17 @@ An simple example is like this: ...@@ -90,16 +90,17 @@ An simple example is like this:
# Efficient Raymarching: Skip empty and occluded space, pack samples from all rays. # Efficient Raymarching: Skip empty and occluded space, pack samples from all rays.
# packed_info: (n_rays, 2). t_starts: (n_samples, 1). t_ends: (n_samples, 1). # packed_info: (n_rays, 2). t_starts: (n_samples, 1). t_ends: (n_samples, 1).
packed_info, t_starts, t_ends = nerfacc.ray_marching( with torch.no_grad():
rays_o, rays_d, sigma_fn=sigma_fn, near_plane=0.2, far_plane=1.0, packed_info, t_starts, t_ends = nerfacc.ray_marching(
early_stop_eps=1e-4, alpha_thre=1e-2, rays_o, rays_d, sigma_fn=sigma_fn, near_plane=0.2, far_plane=1.0,
) early_stop_eps=1e-4, alpha_thre=1e-2,
)
# Differentiable Volumetric Rendering. # Differentiable Volumetric Rendering.
# colors: (n_rays, 3). opaicity: (n_rays, 1). depth: (n_rays, 1). # colors: (n_rays, 3). opaicity: (n_rays, 1). depth: (n_rays, 1).
color, opacity, depth = nerfacc.rendering(rgb_sigma_fn, packed_info, t_starts, t_ends) color, opacity, depth = nerfacc.rendering(rgb_sigma_fn, packed_info, t_starts, t_ends)
# Optimize the radience field. # Optimize: Both the network and rays will receive gradients
optimizer.zero_grad() optimizer.zero_grad()
loss = F.mse_loss(color, color_gt) loss = F.mse_loss(color, color_gt)
loss.backward() loss.backward()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment