Commit ad2a0079 authored by Ruilong Li

proposal_nets_require_grads: 7k, 226s, 34.96db, loss 57

parent 6f7f9fb0
@@ -42,6 +42,7 @@ def render_image(
     render_bkgd: Optional[torch.Tensor] = None,
     cone_angle: float = 0.0,
     alpha_thre: float = 0.0,
+    proposal_nets_require_grads: bool = True,
     # test options
     test_chunk_size: int = 8192,
 ):
@@ -94,6 +95,7 @@ def render_image(
             stratified=radiance_field.training,
             cone_angle=cone_angle,
             alpha_thre=alpha_thre,
+            proposal_nets_require_grads=proposal_nets_require_grads,
         )
         rgb, opacity, depth, weights = rendering(
             t_starts,
@@ -312,91 +314,102 @@ if __name__ == "__main__":
         radiance_field.train()
         proposal_nets.train()

-        data = train_dataset[i]
-
-        render_bkgd = data["color_bkgd"]
-        rays = data["rays"]
-        pixels = data["pixels"]
-
-        # render
-        (
-            rgb,
-            acc,
-            depth,
-            n_rendering_samples,
-            proposal_sample_list,
-        ) = render_image(
-            radiance_field,
-            proposal_nets,
-            rays,
-            scene_aabb,
-            # rendering options
-            near_plane=near_plane,
-            far_plane=far_plane,
-            render_step_size=render_step_size,
-            render_bkgd=render_bkgd,
-            cone_angle=args.cone_angle,
-            alpha_thre=min(alpha_thre, alpha_thre * step / 1000),
-        )
-        if n_rendering_samples == 0:
-            continue
-
-        # dynamic batch size for rays to keep sample batch size constant.
-        num_rays = len(pixels)
-        num_rays = int(
-            num_rays
-            * (target_sample_batch_size / float(n_rendering_samples))
-        )
-        train_dataset.update_num_rays(num_rays)
-        alive_ray_mask = acc.squeeze(-1) > 0
-
-        # compute loss
-        loss = F.smooth_l1_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
-
-        (
-            packed_info,
-            t_starts,
-            t_ends,
-            weights,
-        ) = proposal_sample_list[-1]
-        loss_interval = 0.0
-        for (
-            proposal_packed_info,
-            proposal_t_starts,
-            proposal_t_ends,
-            proposal_weights,
-        ) in proposal_sample_list[:-1]:
-            proposal_weights_gt = ray_pdf_query(
-                packed_info,
-                t_starts,
-                t_ends,
-                weights.detach(),
-                proposal_packed_info,
-                proposal_t_starts,
-                proposal_t_ends,
-            ).detach()
-
-            loss_interval = (
-                torch.clamp(proposal_weights_gt - proposal_weights, min=0)
-            ) ** 2 / (proposal_weights + torch.finfo(torch.float32).eps)
-            loss_interval = loss_interval.mean()
-            loss += loss_interval * 1.0
-
-        optimizer.zero_grad()
-        # do not unscale it because we are using Adam.
-        grad_scaler.scale(loss).backward()
-        optimizer.step()
-        scheduler.step()
-
-        if step % 100 == 0:
-            elapsed_time = time.time() - tic
-            loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
-            print(
-                f"elapsed_time={elapsed_time:.2f}s | step={step} | "
-                f"loss={loss:.5f} | loss_interval={loss_interval:.5f} "
-                f"alive_ray_mask={alive_ray_mask.long().sum():d} | "
-                f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} |"
-            )
+        # @profile
+        def _train():
+            data = train_dataset[i]
+
+            render_bkgd = data["color_bkgd"]
+            rays = data["rays"]
+            pixels = data["pixels"]
+
+            # render
+            (
+                rgb,
+                acc,
+                depth,
+                n_rendering_samples,
+                proposal_sample_list,
+            ) = render_image(
+                radiance_field,
+                proposal_nets,
+                rays,
+                scene_aabb,
+                # rendering options
+                near_plane=near_plane,
+                far_plane=far_plane,
+                render_step_size=render_step_size,
+                render_bkgd=render_bkgd,
+                cone_angle=args.cone_angle,
+                alpha_thre=min(alpha_thre, alpha_thre * step / 1000),
+                proposal_nets_require_grads=(step < 100 or step % 16 == 0),
+            )
+            # if n_rendering_samples == 0:
+            #     continue
+
+            # dynamic batch size for rays to keep sample batch size constant.
+            num_rays = len(pixels)
+            num_rays = int(
+                num_rays
+                * (target_sample_batch_size / float(n_rendering_samples))
+            )
+            train_dataset.update_num_rays(num_rays)
+            alive_ray_mask = acc.squeeze(-1) > 0
+
+            # compute loss
+            loss = F.smooth_l1_loss(
+                rgb[alive_ray_mask], pixels[alive_ray_mask]
+            )
+
+            (
+                packed_info,
+                t_starts,
+                t_ends,
+                weights,
+            ) = proposal_sample_list[-1]
+            loss_interval = 0.0
+            for (
+                proposal_packed_info,
+                proposal_t_starts,
+                proposal_t_ends,
+                proposal_weights,
+            ) in proposal_sample_list[:-1]:
+                proposal_weights_gt = ray_pdf_query(
+                    packed_info,
+                    t_starts,
+                    t_ends,
+                    weights.detach(),
+                    proposal_packed_info,
+                    proposal_t_starts,
+                    proposal_t_ends,
+                ).detach()
+
+                loss_interval = (
+                    torch.clamp(
+                        proposal_weights_gt - proposal_weights, min=0
+                    )
+                ) ** 2 / (proposal_weights + torch.finfo(torch.float32).eps)
+                loss_interval = loss_interval.mean()
+                loss += loss_interval * 1.0
+
+            optimizer.zero_grad()
+            # do not unscale it because we are using Adam.
+            grad_scaler.scale(loss).backward()
+            optimizer.step()
+            scheduler.step()
+
+            if step % 100 == 0:
+                elapsed_time = time.time() - tic
+                loss = F.mse_loss(
+                    rgb[alive_ray_mask], pixels[alive_ray_mask]
+                )
+                print(
+                    f"elapsed_time={elapsed_time:.2f}s | step={step} | "
+                    f"loss={loss:.5f} | loss_interval={loss_interval:.5f} "
+                    f"alive_ray_mask={alive_ray_mask.long().sum():d} | "
+                    f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} |"
+                )
+
+        _train()

         if step >= 0 and step % 1000 == 0 and step > 0:
             # evaluation
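For reference, the interval loss accumulated in the loop above can be written as a small standalone helper. A hedged sketch: the helper name and the flat (N, 1) weight tensors are illustrative assumptions, not code from this commit.

import torch

def proposal_interval_loss(
    proposal_weights: torch.Tensor,     # (N, 1) weights predicted by a proposal net
    proposal_weights_gt: torch.Tensor,  # (N, 1) weights queried from the final samples (detached)
    eps: float = torch.finfo(torch.float32).eps,
) -> torch.Tensor:
    # Penalize proposal weights that under-estimate the final weights:
    # clamp(gt - proposal, min=0)^2 / (proposal + eps), averaged over samples.
    diff = torch.clamp(proposal_weights_gt - proposal_weights, min=0.0)
    return (diff**2 / (proposal_weights + eps)).mean()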
@@ -424,6 +437,7 @@ if __name__ == "__main__":
                         render_bkgd=render_bkgd,
                         cone_angle=args.cone_angle,
                         alpha_thre=alpha_thre,
+                        proposal_nets_require_grads=False,
                         # test options
                         test_chunk_size=args.test_chunk_size,
                     )
......
@@ -23,6 +23,7 @@ grid_query = _make_lazy_cuda_func("grid_query")
 ray_aabb_intersect = _make_lazy_cuda_func("ray_aabb_intersect")
 ray_marching = _make_lazy_cuda_func("ray_marching")
+ray_marching_with_grid = _make_lazy_cuda_func("ray_marching_with_grid")
 ray_resampling = _make_lazy_cuda_func("ray_resampling")
 ray_pdf_query = _make_lazy_cuda_func("ray_pdf_query")
......
@@ -13,6 +13,15 @@ std::vector<torch::Tensor> ray_aabb_intersect(
     const torch::Tensor aabb);

 std::vector<torch::Tensor> ray_marching(
+    // rays
+    const torch::Tensor rays_o,
+    const torch::Tensor rays_d,
+    const torch::Tensor t_min,
+    const torch::Tensor t_max,
+    // sampling
+    const float step_size,
+    const float cone_angle);
+std::vector<torch::Tensor> ray_marching_with_grid(
     // rays
     const torch::Tensor rays_o,
     const torch::Tensor rays_d,
@@ -153,6 +162,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
     // marching
     m.def("ray_aabb_intersect", &ray_aabb_intersect);
     m.def("ray_marching", &ray_marching);
+    m.def("ray_marching_with_grid", &ray_marching_with_grid);
     m.def("ray_resampling", &ray_resampling);
     m.def("ray_pdf_query", &ray_pdf_query);
......
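Once the extension is rebuilt, the new grid-free op can also be exercised directly, mirroring the C++ signature declared above. A hedged sketch with made-up ray tensors; everything besides the op name, argument order, and return values shown in this diff is illustrative.

import torch
import torch.nn.functional as F
import nerfacc.cuda as _C

# toy CUDA rays; tensors must be contiguous float32 on the GPU
rays_o = torch.rand(128, 3, device="cuda")
rays_d = F.normalize(torch.rand(128, 3, device="cuda"), dim=-1)
t_min = torch.zeros(128, device="cuda")
t_max = torch.full((128,), 10.0, device="cuda")

packed_info, ray_indices, t_starts, t_ends = _C.ray_marching(
    rays_o.contiguous(),
    rays_d.contiguous(),
    t_min.contiguous(),
    t_max.contiguous(),
    1e-2,  # step_size
    0.0,   # cone_angle
)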
@@ -76,7 +76,85 @@ inline __device__ __host__ float advance_to_next_voxel(
 // Raymarching
 // -------------------------------------------------------------------------------

 __global__ void ray_marching_kernel(
+    // rays info
+    const uint32_t n_rays,
+    const float *rays_o, // shape (n_rays, 3)
+    const float *rays_d, // shape (n_rays, 3)
+    const float *t_min,  // shape (n_rays,)
+    const float *t_max,  // shape (n_rays,)
+    // sampling
+    const float step_size,
+    const float cone_angle,
+    const int *packed_info,
+    // first round outputs
+    int *num_steps,
+    // second round outputs
+    int *ray_indices,
+    float *t_starts,
+    float *t_ends)
+{
+    CUDA_GET_THREAD_ID(i, n_rays);
+
+    bool is_first_round = (packed_info == nullptr);
+
+    // locate
+    rays_o += i * 3;
+    rays_d += i * 3;
+    t_min += i;
+    t_max += i;
+    if (is_first_round)
+    {
+        num_steps += i;
+    }
+    else
+    {
+        int base = packed_info[i * 2 + 0];
+        int steps = packed_info[i * 2 + 1];
+        t_starts += base;
+        t_ends += base;
+        ray_indices += base;
+    }
+
+    const float3 origin = make_float3(rays_o[0], rays_o[1], rays_o[2]);
+    const float3 dir = make_float3(rays_d[0], rays_d[1], rays_d[2]);
+    const float3 inv_dir = 1.0f / dir;
+    const float near = t_min[0], far = t_max[0];
+
+    float dt_min = step_size;
+    float dt_max = 1e10f;
+
+    int j = 0;
+    float t0 = near;
+    float dt = calc_dt(t0, cone_angle, dt_min, dt_max);
+    float t1 = t0 + dt;
+    float t_mid = (t0 + t1) * 0.5f;
+
+    while (t_mid < far)
+    {
+        if (!is_first_round)
+        {
+            t_starts[j] = t0;
+            t_ends[j] = t1;
+            ray_indices[j] = i;
+        }
+        ++j;
+        // march to next sample
+        t0 = t1;
+        t1 = t0 + calc_dt(t0, cone_angle, dt_min, dt_max);
+        t_mid = (t0 + t1) * 0.5f;
+    }
+
+    if (is_first_round)
+    {
+        *num_steps = j;
+    }
+    return;
+}
+
+__global__ void ray_marching_with_grid_kernel(
     // rays info
     const uint32_t n_rays,
     const float *rays_o, // shape (n_rays, 3)
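A hedged single-ray Python reference for what the grid-free kernel above records: fixed (or cone-scaled) intervals between t_min and t_max. calc_dt is assumed to behave like clamp(t * cone_angle, step_size, 1e10); treat that as an assumption rather than a quote of the CUDA helper.

def march_ray(near: float, far: float, step_size: float, cone_angle: float = 0.0):
    # assumed calc_dt: constant step when cone_angle == 0, else scaled with distance
    def calc_dt(t: float) -> float:
        return min(max(t * cone_angle, step_size), 1e10)

    t_starts, t_ends = [], []
    t0 = near
    t1 = t0 + calc_dt(t0)
    while (t0 + t1) * 0.5 < far:  # same stopping rule as the kernel (midpoint test)
        t_starts.append(t0)
        t_ends.append(t1)
        t0 = t1
        t1 = t0 + calc_dt(t0)
    return t_starts, t_ends

# march_ray(0.0, 0.05, 0.01) -> five 0.01-wide intervals starting at 0.0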
@@ -189,7 +267,84 @@ __global__ void ray_marching_kernel(
     return;
 }

 std::vector<torch::Tensor> ray_marching(
+    // rays
+    const torch::Tensor rays_o,
+    const torch::Tensor rays_d,
+    const torch::Tensor t_min,
+    const torch::Tensor t_max,
+    // sampling
+    const float step_size,
+    const float cone_angle)
+{
+    DEVICE_GUARD(rays_o);
+
+    CHECK_INPUT(rays_o);
+    CHECK_INPUT(rays_d);
+    CHECK_INPUT(t_min);
+    CHECK_INPUT(t_max);
+    TORCH_CHECK(rays_o.ndimension() == 2 & rays_o.size(1) == 3)
+    TORCH_CHECK(rays_d.ndimension() == 2 & rays_d.size(1) == 3)
+    TORCH_CHECK(t_min.ndimension() == 1)
+    TORCH_CHECK(t_max.ndimension() == 1)
+
+    const int n_rays = rays_o.size(0);
+
+    const int threads = 256;
+    const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
+
+    // helper counter
+    torch::Tensor num_steps = torch::empty(
+        {n_rays}, rays_o.options().dtype(torch::kInt32));
+
+    // count number of samples per ray
+    ray_marching_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
+        // rays
+        n_rays,
+        rays_o.data_ptr<float>(),
+        rays_d.data_ptr<float>(),
+        t_min.data_ptr<float>(),
+        t_max.data_ptr<float>(),
+        // sampling
+        step_size,
+        cone_angle,
+        nullptr, /* packed_info */
+        // outputs
+        num_steps.data_ptr<int>(),
+        nullptr, /* ray_indices */
+        nullptr, /* t_starts */
+        nullptr /* t_ends */);
+
+    torch::Tensor cum_steps = num_steps.cumsum(0, torch::kInt32);
+    torch::Tensor packed_info = torch::stack({cum_steps - num_steps, num_steps}, 1);
+
+    // output samples starts and ends
+    int total_steps = cum_steps[cum_steps.size(0) - 1].item<int>();
+    torch::Tensor t_starts = torch::empty({total_steps, 1}, rays_o.options());
+    torch::Tensor t_ends = torch::empty({total_steps, 1}, rays_o.options());
+    torch::Tensor ray_indices = torch::empty({total_steps}, cum_steps.options());
+
+    ray_marching_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
+        // rays
+        n_rays,
+        rays_o.data_ptr<float>(),
+        rays_d.data_ptr<float>(),
+        t_min.data_ptr<float>(),
+        t_max.data_ptr<float>(),
+        // sampling
+        step_size,
+        cone_angle,
+        packed_info.data_ptr<int>(),
+        // outputs
+        nullptr, /* num_steps */
+        ray_indices.data_ptr<int>(),
+        t_starts.data_ptr<float>(),
+        t_ends.data_ptr<float>());
+
+    return {packed_info, ray_indices, t_starts, t_ends};
+}
+
+std::vector<torch::Tensor> ray_marching_with_grid(
     // rays
     const torch::Tensor rays_o,
     const torch::Tensor rays_d,
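The packed_info bookkeeping built between the two kernel launches above can be reproduced with plain tensor ops. A hedged sketch on a made-up per-ray sample count:

import torch

num_steps = torch.tensor([3, 0, 2], dtype=torch.int32)        # samples counted per ray
cum_steps = num_steps.cumsum(0, dtype=torch.int32)             # [3, 3, 5]
packed_info = torch.stack([cum_steps - num_steps, num_steps], dim=1)
# packed_info == [[0, 3], [3, 0], [3, 2]]: row i holds (start offset, count) of ray i
total_steps = int(cum_steps[-1])                                # 5 flattened samples in t_starts/t_ends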
@@ -230,7 +385,7 @@ std::vector<torch::Tensor> ray_marching(
         {n_rays}, rays_o.options().dtype(torch::kInt32));

     // count number of samples per ray
-    ray_marching_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
+    ray_marching_with_grid_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
         // rays
         n_rays,
         rays_o.data_ptr<float>(),
@@ -261,7 +416,7 @@ std::vector<torch::Tensor> ray_marching(
     torch::Tensor t_ends = torch::empty({total_steps, 1}, rays_o.options());
     torch::Tensor ray_indices = torch::empty({total_steps}, cum_steps.options());

-    ray_marching_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
+    ray_marching_with_grid_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
         // rays
         n_rays,
         rays_o.data_ptr<float>(),
......
@@ -5,11 +5,55 @@ import torch

 import nerfacc.cuda as _C

 from .cdf import ray_resampling
-from .contraction import ContractionType
 from .grid import Grid
 from .intersection import ray_aabb_intersect
 from .pack import pack_info, unpack_info
-from .vol_rendering import render_visibility, render_weight_from_density
+from .vol_rendering import (
+    render_visibility,
+    render_weight_from_alpha,
+    render_weight_from_density,
+)
+
+
+@torch.no_grad()
+def maybe_filter(
+    t_starts: torch.Tensor,
+    t_ends: torch.Tensor,
+    ray_indices: torch.Tensor,
+    n_rays: int,
+    # sigma/alpha function for skipping invisible space
+    sigma_fn: Optional[Callable] = None,
+    alpha_fn: Optional[Callable] = None,
+    net: Optional[torch.nn.Module] = None,
+    early_stop_eps: float = 1e-4,
+    alpha_thre: float = 0.0,
+):
+    alphas = None
+    if sigma_fn is not None:
+        alpha_fn = lambda *args: 1.0 - torch.exp(
+            -sigma_fn(*args) * (t_ends - t_starts)
+        )
+    if alpha_fn is not None:
+        alphas = alpha_fn(t_starts, t_ends, ray_indices.long(), net)
+        assert (
+            alphas.shape == t_starts.shape
+        ), "alphas must have shape of (N, 1)! Got {}".format(alphas.shape)
+        # Compute visibility of the samples, and filter out invisible samples
+        masks = render_visibility(
+            alphas,
+            ray_indices=ray_indices,
+            early_stop_eps=early_stop_eps,
+            alpha_thre=alpha_thre,
+            n_rays=n_rays,
+        )
+        ray_indices, t_starts, t_ends, alphas = (
+            ray_indices[masks],
+            t_starts[masks],
+            t_ends[masks],
+            alphas[masks],
+        )
+    return ray_indices, t_starts, t_ends, alphas


 @torch.no_grad()
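maybe_filter converts densities into per-interval opacities before the visibility test via alpha = 1 - exp(-sigma * (t_end - t_start)). A hedged numeric sketch of that conversion (the toy values are illustrative):

import torch

sigmas = torch.tensor([[0.0], [1.0], [50.0]])   # densities, shape (N, 1)
dts = torch.tensor([[0.01], [0.01], [0.01]])    # interval lengths t_ends - t_starts
alphas = 1.0 - torch.exp(-sigmas * dts)         # ~[[0.0], [0.00995], [0.3935]]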
@@ -30,6 +74,7 @@ def ray_marching(
     proposal_nets: Optional[torch.nn.Module] = None,
     early_stop_eps: float = 1e-4,
     alpha_thre: float = 0.0,
+    proposal_nets_require_grads: bool = True,
     # rendering options
     near_plane: Optional[float] = None,
     far_plane: Optional[float] = None,
@@ -132,6 +177,9 @@ def ray_marching(
         sample_locs = rays_o[ray_indices] + t_mid * rays_d[ray_indices]

     """
+    torch.cuda.synchronize()
+    n_rays = rays_o.shape[0]
+
     if not rays_o.is_cuda:
         raise NotImplementedError("Only support cuda inputs.")
     if alpha_fn is not None and sigma_fn is not None:
@@ -163,131 +211,83 @@ def ray_marching(
     # use grid for skipping if given
     if grid is not None:
-        grid_roi_aabb = grid.roi_aabb
-        grid_binary = grid.binary
-        contraction_type = grid.contraction_type.to_cpp_version()
-    else:
-        grid_roi_aabb = torch.tensor(
-            [-1e10, -1e10, -1e10, 1e10, 1e10, 1e10],
-            dtype=torch.float32,
-            device=rays_o.device,
-        )
-        grid_binary = torch.ones(
-            [1, 1, 1], dtype=torch.bool, device=rays_o.device
-        )
-        contraction_type = ContractionType.AABB.to_cpp_version()
-
-    # marching with grid-based skipping
-    packed_info, ray_indices, t_starts, t_ends = _C.ray_marching(
-        # rays
-        rays_o.contiguous(),
-        rays_d.contiguous(),
-        t_min.contiguous(),
-        t_max.contiguous(),
-        # coontraction and grid
-        grid_roi_aabb.contiguous(),
-        grid_binary.contiguous(),
-        contraction_type,
-        # sampling
-        render_step_size,
-        cone_angle,
-    )
+        # marching with grid-based skipping
+        packed_info, ray_indices, t_starts, t_ends = _C.ray_marching_with_grid(
+            # rays
+            rays_o.contiguous(),
+            rays_d.contiguous(),
+            t_min.contiguous(),
+            t_max.contiguous(),
+            # coontraction and grid
+            grid.roi_aabb.contiguous(),
+            grid.binary.contiguous(),
+            grid.contraction_type.to_cpp_version(),
+            # sampling
+            render_step_size,
+            cone_angle,
+        )
+    else:
+        # marching
+        packed_info, ray_indices, t_starts, t_ends = _C.ray_marching(
+            # rays
+            rays_o.contiguous(),
+            rays_d.contiguous(),
+            t_min.contiguous(),
+            t_max.contiguous(),
+            # sampling
+            render_step_size,
+            cone_angle,
+        )

     proposal_sample_list = []
     if proposal_nets is not None:
         # resample with proposal nets
         for net, num_samples in zip(proposal_nets, [32]):
-            with torch.no_grad():
-                # skip invisible space
-                if sigma_fn is not None or alpha_fn is not None:
-                    # Query sigma without gradients
-                    if sigma_fn is not None:
-                        sigmas = sigma_fn(
-                            t_starts, t_ends, ray_indices.long(), net=net
-                        )
-                        assert (
-                            sigmas.shape == t_starts.shape
-                        ), "sigmas must have shape of (N, 1)! Got {}".format(
-                            sigmas.shape
-                        )
-                        alphas = 1.0 - torch.exp(-sigmas * (t_ends - t_starts))
-                    elif alpha_fn is not None:
-                        alphas = alpha_fn(
-                            t_starts, t_ends, ray_indices.long(), net=net
-                        )
-                        assert (
-                            alphas.shape == t_starts.shape
-                        ), "alphas must have shape of (N, 1)! Got {}".format(
-                            alphas.shape
-                        )
-
-                    # Compute visibility of the samples, and filter out invisible samples
-                    masks = render_visibility(
-                        alphas,
-                        ray_indices=ray_indices,
-                        early_stop_eps=early_stop_eps,
-                        alpha_thre=alpha_thre,
-                        n_rays=rays_o.shape[0],
-                    )
-                    ray_indices, t_starts, t_ends = (
-                        ray_indices[masks],
-                        t_starts[masks],
-                        t_ends[masks],
-                    )
-                    # print(
-                    #     alphas.shape,
-                    #     masks.float().sum(),
-                    #     alphas.min(),
-                    #     alphas.max(),
-                    # )
-
-            with torch.enable_grad():
-                sigmas = sigma_fn(t_starts, t_ends, ray_indices.long(), net=net)
-                weights = render_weight_from_density(
-                    t_starts, t_ends, sigmas, ray_indices=ray_indices
-                )
-            packed_info = pack_info(ray_indices, n_rays=rays_o.shape[0])
-            proposal_sample_list.append(
-                (packed_info, t_starts, t_ends, weights)
-            )
+            ray_indices, t_starts, t_ends, alphas = maybe_filter(
+                t_starts=t_starts,
+                t_ends=t_ends,
+                ray_indices=ray_indices,
+                n_rays=n_rays,
+                sigma_fn=sigma_fn,
+                alpha_fn=alpha_fn,
+                net=net,
+                early_stop_eps=early_stop_eps,
+                alpha_thre=alpha_thre,
+            )
+            packed_info = pack_info(ray_indices, n_rays=n_rays)
+
+            if proposal_nets_require_grads:
+                with torch.enable_grad():
+                    sigmas = sigma_fn(
+                        t_starts, t_ends, ray_indices.long(), net=net
+                    )
+                    weights = render_weight_from_density(
+                        t_starts, t_ends, sigmas, ray_indices=ray_indices
+                    )
+                proposal_sample_list.append(
+                    (packed_info, t_starts, t_ends, weights)
+                )
+            else:
+                weights = render_weight_from_alpha(
+                    alphas, ray_indices=ray_indices
+                )
+
             packed_info, t_starts, t_ends = ray_resampling(
                 packed_info, t_starts, t_ends, weights, n_samples=num_samples
             )
             ray_indices = unpack_info(packed_info, n_samples=t_starts.shape[0])

-        with torch.no_grad():
-            # skip invisible space
-            if sigma_fn is not None or alpha_fn is not None:
-                # Query sigma without gradients
-                if sigma_fn is not None:
-                    sigmas = sigma_fn(t_starts, t_ends, ray_indices.long())
-                    assert (
-                        sigmas.shape == t_starts.shape
-                    ), "sigmas must have shape of (N, 1)! Got {}".format(
-                        sigmas.shape
-                    )
-                    alphas = 1.0 - torch.exp(-sigmas * (t_ends - t_starts))
-                elif alpha_fn is not None:
-                    alphas = alpha_fn(t_starts, t_ends, ray_indices.long())
-                    assert (
-                        alphas.shape == t_starts.shape
-                    ), "alphas must have shape of (N, 1)! Got {}".format(
-                        alphas.shape
-                    )
-
-                # Compute visibility of the samples, and filter out invisible samples
-                masks = render_visibility(
-                    alphas,
-                    ray_indices=ray_indices,
-                    early_stop_eps=early_stop_eps,
-                    alpha_thre=alpha_thre,
-                    n_rays=rays_o.shape[0],
-                )
-                ray_indices, t_starts, t_ends = (
-                    ray_indices[masks],
-                    t_starts[masks],
-                    t_ends[masks],
-                )
+        ray_indices, t_starts, t_ends, _ = maybe_filter(
+            t_starts=t_starts,
+            t_ends=t_ends,
+            ray_indices=ray_indices,
+            n_rays=n_rays,
+            sigma_fn=sigma_fn,
+            alpha_fn=alpha_fn,
+            net=None,
+            early_stop_eps=early_stop_eps,
+            alpha_thre=alpha_thre,
+        )

     if proposal_nets is not None:
         return ray_indices, t_starts, t_ends, proposal_sample_list
......