Commit ad2a0079 authored by Ruilong Li

proposal_nets_require_grads: 7k, 226s, 34.96db, loss 57

parent 6f7f9fb0
@@ -42,6 +42,7 @@ def render_image(
     render_bkgd: Optional[torch.Tensor] = None,
     cone_angle: float = 0.0,
     alpha_thre: float = 0.0,
+    proposal_nets_require_grads: bool = True,
     # test options
     test_chunk_size: int = 8192,
 ):
@@ -94,6 +95,7 @@ def render_image(
             stratified=radiance_field.training,
             cone_angle=cone_angle,
             alpha_thre=alpha_thre,
+            proposal_nets_require_grads=proposal_nets_require_grads,
         )
         rgb, opacity, depth, weights = rendering(
             t_starts,
@@ -312,6 +314,8 @@ if __name__ == "__main__":
         radiance_field.train()
         proposal_nets.train()

+        # @profile
+        def _train():
             data = train_dataset[i]
             render_bkgd = data["color_bkgd"]
@@ -337,9 +341,10 @@ if __name__ == "__main__":
                 render_bkgd=render_bkgd,
                 cone_angle=args.cone_angle,
                 alpha_thre=min(alpha_thre, alpha_thre * step / 1000),
+                proposal_nets_require_grads=(step < 100 or step % 16 == 0),
             )
-            if n_rendering_samples == 0:
-                continue
+            # if n_rendering_samples == 0:
+            #     continue

             # dynamic batch size for rays to keep sample batch size constant.
             num_rays = len(pixels)
@@ -351,7 +356,9 @@ if __name__ == "__main__":
             alive_ray_mask = acc.squeeze(-1) > 0

             # compute loss
-            loss = F.smooth_l1_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
+            loss = F.smooth_l1_loss(
+                rgb[alive_ray_mask], pixels[alive_ray_mask]
+            )
             (
                 packed_info,
@@ -377,7 +384,9 @@ if __name__ == "__main__":
             ).detach()
             loss_interval = (
-                torch.clamp(proposal_weights_gt - proposal_weights, min=0)
+                torch.clamp(
+                    proposal_weights_gt - proposal_weights, min=0
+                )
             ) ** 2 / (proposal_weights + torch.finfo(torch.float32).eps)
             loss_interval = loss_interval.mean()
             loss += loss_interval * 1.0
@@ -390,7 +399,9 @@ if __name__ == "__main__":
             if step % 100 == 0:
                 elapsed_time = time.time() - tic
-                loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
+                loss = F.mse_loss(
+                    rgb[alive_ray_mask], pixels[alive_ray_mask]
+                )
                 print(
                     f"elapsed_time={elapsed_time:.2f}s | step={step} | "
                     f"loss={loss:.5f} | loss_interval={loss_interval:.5f} "
@@ -398,6 +409,8 @@ if __name__ == "__main__":
                     f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} |"
                 )

+        _train()
+
         if step >= 0 and step % 1000 == 0 and step > 0:
             # evaluation
             radiance_field.eval()
@@ -424,6 +437,7 @@ if __name__ == "__main__":
                         render_bkgd=render_bkgd,
                         cone_angle=args.cone_angle,
                         alpha_thre=alpha_thre,
+                        proposal_nets_require_grads=False,
                         # test options
                         test_chunk_size=args.test_chunk_size,
                     )
......
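The training-side changes above do two things: the proposal networks only receive gradients during the first 100 steps and on every 16th step after that (proposal_nets_require_grads=(step < 100 or step % 16 == 0)), and they are supervised by a clamped, weight-normalized squared error against the detached NeRF weights (loss += loss_interval * 1.0). A minimal sketch of both pieces; the helper names proposal_grads_enabled and proposal_interval_loss are illustrative only and not part of the commit:

import torch


def proposal_grads_enabled(step: int, warmup: int = 100, every: int = 16) -> bool:
    # Schedule used at the render_image() call site: always give the proposal
    # networks gradients during warm-up, then only on every `every`-th step.
    return step < warmup or step % every == 0


def proposal_interval_loss(
    proposal_weights: torch.Tensor, proposal_weights_gt: torch.Tensor
) -> torch.Tensor:
    # Penalize proposal weights that under-estimate the (already detached)
    # NeRF weights, normalized by the proposal weights themselves.
    eps = torch.finfo(torch.float32).eps
    diff = torch.clamp(proposal_weights_gt.detach() - proposal_weights, min=0)
    return (diff**2 / (proposal_weights + eps)).mean()

The schedule trades some proposal-network freshness for fewer backward passes through the proposal MLPs, while the interval loss itself stays cheap because the target weights are detached.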
@@ -23,6 +23,7 @@ grid_query = _make_lazy_cuda_func("grid_query")
 ray_aabb_intersect = _make_lazy_cuda_func("ray_aabb_intersect")
 ray_marching = _make_lazy_cuda_func("ray_marching")
+ray_marching_with_grid = _make_lazy_cuda_func("ray_marching_with_grid")
 ray_resampling = _make_lazy_cuda_func("ray_resampling")
 ray_pdf_query = _make_lazy_cuda_func("ray_pdf_query")
......
@@ -13,6 +13,15 @@ std::vector<torch::Tensor> ray_aabb_intersect(
     const torch::Tensor aabb);

 std::vector<torch::Tensor> ray_marching(
+    // rays
+    const torch::Tensor rays_o,
+    const torch::Tensor rays_d,
+    const torch::Tensor t_min,
+    const torch::Tensor t_max,
+    // sampling
+    const float step_size,
+    const float cone_angle);
+std::vector<torch::Tensor> ray_marching_with_grid(
     // rays
     const torch::Tensor rays_o,
     const torch::Tensor rays_d,
@@ -153,6 +162,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
     // marching
     m.def("ray_aabb_intersect", &ray_aabb_intersect);
     m.def("ray_marching", &ray_marching);
+    m.def("ray_marching_with_grid", &ray_marching_with_grid);
     m.def("ray_resampling", &ray_resampling);
     m.def("ray_pdf_query", &ray_pdf_query);
......
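With the declarations above, the extension exposes two marching entry points: ray_marching, which samples uniformly between the per-ray t_min/t_max bounds, and ray_marching_with_grid, the renamed original that additionally takes an ROI AABB, a binary occupancy grid, and a contraction type for empty-space skipping. A small dispatch sketch from the Python side, mirroring the argument lists declared here; the wrapper name march is hypothetical, and the real dispatch lives in nerfacc/ray_marching.py further down:

import nerfacc.cuda as _C


def march(rays_o, rays_d, t_min, t_max, render_step_size, cone_angle, grid=None):
    # Both branches return (packed_info, ray_indices, t_starts, t_ends).
    if grid is None:
        # uniform marching between the per-ray near/far bounds
        return _C.ray_marching(
            rays_o.contiguous(), rays_d.contiguous(),
            t_min.contiguous(), t_max.contiguous(),
            render_step_size, cone_angle,
        )
    # grid-based empty-space skipping (the renamed original entry point)
    return _C.ray_marching_with_grid(
        rays_o.contiguous(), rays_d.contiguous(),
        t_min.contiguous(), t_max.contiguous(),
        grid.roi_aabb.contiguous(), grid.binary.contiguous(),
        grid.contraction_type.to_cpp_version(),
        render_step_size, cone_angle,
    )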
@@ -76,7 +76,85 @@ inline __device__ __host__ float advance_to_next_voxel(
 // Raymarching
 // -------------------------------------------------------------------------------

 __global__ void ray_marching_kernel(
+    // rays info
+    const uint32_t n_rays,
+    const float *rays_o, // shape (n_rays, 3)
+    const float *rays_d, // shape (n_rays, 3)
+    const float *t_min,  // shape (n_rays,)
+    const float *t_max,  // shape (n_rays,)
+    // sampling
+    const float step_size,
+    const float cone_angle,
+    const int *packed_info,
+    // first round outputs
+    int *num_steps,
+    // second round outputs
+    int *ray_indices,
+    float *t_starts,
+    float *t_ends)
+{
+    CUDA_GET_THREAD_ID(i, n_rays);
+
+    bool is_first_round = (packed_info == nullptr);
+
+    // locate
+    rays_o += i * 3;
+    rays_d += i * 3;
+    t_min += i;
+    t_max += i;
+    if (is_first_round)
+    {
+        num_steps += i;
+    }
+    else
+    {
+        int base = packed_info[i * 2 + 0];
+        int steps = packed_info[i * 2 + 1];
+        t_starts += base;
+        t_ends += base;
+        ray_indices += base;
+    }
+
+    const float3 origin = make_float3(rays_o[0], rays_o[1], rays_o[2]);
+    const float3 dir = make_float3(rays_d[0], rays_d[1], rays_d[2]);
+    const float3 inv_dir = 1.0f / dir;
+    const float near = t_min[0], far = t_max[0];
+
+    float dt_min = step_size;
+    float dt_max = 1e10f;
+
+    int j = 0;
+    float t0 = near;
+    float dt = calc_dt(t0, cone_angle, dt_min, dt_max);
+    float t1 = t0 + dt;
+    float t_mid = (t0 + t1) * 0.5f;
+    while (t_mid < far)
+    {
+        if (!is_first_round)
+        {
+            t_starts[j] = t0;
+            t_ends[j] = t1;
+            ray_indices[j] = i;
+        }
+        ++j;
+        // march to next sample
+        t0 = t1;
+        t1 = t0 + calc_dt(t0, cone_angle, dt_min, dt_max);
+        t_mid = (t0 + t1) * 0.5f;
+    }
+    if (is_first_round)
+    {
+        *num_steps = j;
+    }
+    return;
+}
+
+__global__ void ray_marching_with_grid_kernel(
     // rays info
     const uint32_t n_rays,
     const float *rays_o, // shape (n_rays, 3)
@@ -189,7 +267,84 @@ __global__ void ray_marching_kernel(
     return;
 }

 std::vector<torch::Tensor> ray_marching(
+    // rays
+    const torch::Tensor rays_o,
+    const torch::Tensor rays_d,
+    const torch::Tensor t_min,
+    const torch::Tensor t_max,
+    // sampling
+    const float step_size,
+    const float cone_angle)
+{
+    DEVICE_GUARD(rays_o);
+
+    CHECK_INPUT(rays_o);
+    CHECK_INPUT(rays_d);
+    CHECK_INPUT(t_min);
+    CHECK_INPUT(t_max);
+    TORCH_CHECK(rays_o.ndimension() == 2 & rays_o.size(1) == 3)
+    TORCH_CHECK(rays_d.ndimension() == 2 & rays_d.size(1) == 3)
+    TORCH_CHECK(t_min.ndimension() == 1)
+    TORCH_CHECK(t_max.ndimension() == 1)
+
+    const int n_rays = rays_o.size(0);
+
+    const int threads = 256;
+    const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
+
+    // helper counter
+    torch::Tensor num_steps = torch::empty(
+        {n_rays}, rays_o.options().dtype(torch::kInt32));
+
+    // count number of samples per ray
+    ray_marching_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
+        // rays
+        n_rays,
+        rays_o.data_ptr<float>(),
+        rays_d.data_ptr<float>(),
+        t_min.data_ptr<float>(),
+        t_max.data_ptr<float>(),
+        // sampling
+        step_size,
+        cone_angle,
+        nullptr, /* packed_info */
+        // outputs
+        num_steps.data_ptr<int>(),
+        nullptr, /* ray_indices */
+        nullptr, /* t_starts */
+        nullptr /* t_ends */);
+
+    torch::Tensor cum_steps = num_steps.cumsum(0, torch::kInt32);
+    torch::Tensor packed_info = torch::stack({cum_steps - num_steps, num_steps}, 1);
+
+    // output samples starts and ends
+    int total_steps = cum_steps[cum_steps.size(0) - 1].item<int>();
+    torch::Tensor t_starts = torch::empty({total_steps, 1}, rays_o.options());
+    torch::Tensor t_ends = torch::empty({total_steps, 1}, rays_o.options());
+    torch::Tensor ray_indices = torch::empty({total_steps}, cum_steps.options());
+
+    ray_marching_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
+        // rays
+        n_rays,
+        rays_o.data_ptr<float>(),
+        rays_d.data_ptr<float>(),
+        t_min.data_ptr<float>(),
+        t_max.data_ptr<float>(),
+        // sampling
+        step_size,
+        cone_angle,
+        packed_info.data_ptr<int>(),
+        // outputs
+        nullptr, /* num_steps */
+        ray_indices.data_ptr<int>(),
+        t_starts.data_ptr<float>(),
+        t_ends.data_ptr<float>());
+
+    return {packed_info, ray_indices, t_starts, t_ends};
+}
+
+std::vector<torch::Tensor> ray_marching_with_grid(
     // rays
     const torch::Tensor rays_o,
     const torch::Tensor rays_d,
@@ -230,7 +385,7 @@ std::vector<torch::Tensor> ray_marching(
         {n_rays}, rays_o.options().dtype(torch::kInt32));

     // count number of samples per ray
-    ray_marching_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
+    ray_marching_with_grid_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
         // rays
         n_rays,
         rays_o.data_ptr<float>(),
@@ -261,7 +416,7 @@ std::vector<torch::Tensor> ray_marching(
     torch::Tensor t_ends = torch::empty({total_steps, 1}, rays_o.options());
     torch::Tensor ray_indices = torch::empty({total_steps}, cum_steps.options());

-    ray_marching_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
+    ray_marching_with_grid_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
         // rays
         n_rays,
         rays_o.data_ptr<float>(),
......
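The host function above launches the same kernel twice: a first pass with packed_info == nullptr that only counts samples per ray into num_steps, and a second pass that, after a cumulative sum, writes t_starts, t_ends, and ray_indices at the per-ray offsets stored in packed_info = (start, count). A pure-PyTorch reference for the grid-less path, assuming cone_angle == 0 so that calc_dt reduces to the constant step_size; the name uniform_marching_reference is illustrative, and the code checks the packed layout rather than performance:

import math

import torch


def uniform_marching_reference(t_min, t_max, step_size):
    # Mirrors the two-pass kernel for cone_angle == 0: samples cover
    # [near + k*dt, near + (k+1)*dt) and are kept while their midpoint < far.
    t_starts, t_ends, ray_indices, num_steps = [], [], [], []
    for i, (near, far) in enumerate(zip(t_min.tolist(), t_max.tolist())):
        n = max(0, math.ceil((far - near) / step_size - 0.5))
        k = torch.arange(n, dtype=torch.float32)
        t_starts.append(near + k * step_size)
        t_ends.append(near + (k + 1) * step_size)
        ray_indices.append(torch.full((n,), i, dtype=torch.int32))
        num_steps.append(n)
    num_steps = torch.tensor(num_steps, dtype=torch.int32)
    cum_steps = torch.cumsum(num_steps, dim=0)
    # packed_info[i] = (index of ray i's first sample, number of samples)
    packed_info = torch.stack([cum_steps - num_steps, num_steps], dim=1)
    return (
        packed_info,
        torch.cat(ray_indices),
        torch.cat(t_starts)[:, None],  # (total_steps, 1), as in the CUDA op
        torch.cat(t_ends)[:, None],
    )

Comparing its outputs against the CUDA op on a handful of random rays is a cheap way to sanity-check the packing.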
@@ -5,11 +5,55 @@ import torch
 import nerfacc.cuda as _C

 from .cdf import ray_resampling
+from .contraction import ContractionType
 from .grid import Grid
 from .intersection import ray_aabb_intersect
 from .pack import pack_info, unpack_info
-from .vol_rendering import render_visibility, render_weight_from_density
+from .vol_rendering import (
+    render_visibility,
+    render_weight_from_alpha,
+    render_weight_from_density,
+)
+
+
+@torch.no_grad()
+def maybe_filter(
+    t_starts: torch.Tensor,
+    t_ends: torch.Tensor,
+    ray_indices: torch.Tensor,
+    n_rays: int,
+    # sigma/alpha function for skipping invisible space
+    sigma_fn: Optional[Callable] = None,
+    alpha_fn: Optional[Callable] = None,
+    net: Optional[torch.nn.Module] = None,
+    early_stop_eps: float = 1e-4,
+    alpha_thre: float = 0.0,
+):
+    alphas = None
+    if sigma_fn is not None:
+        alpha_fn = lambda *args: 1.0 - torch.exp(
+            -sigma_fn(*args) * (t_ends - t_starts)
+        )
+    if alpha_fn is not None:
+        alphas = alpha_fn(t_starts, t_ends, ray_indices.long(), net)
+        assert (
+            alphas.shape == t_starts.shape
+        ), "alphas must have shape of (N, 1)! Got {}".format(alphas.shape)
+        # Compute visibility of the samples, and filter out invisible samples
+        masks = render_visibility(
+            alphas,
+            ray_indices=ray_indices,
+            early_stop_eps=early_stop_eps,
+            alpha_thre=alpha_thre,
+            n_rays=n_rays,
+        )
+        ray_indices, t_starts, t_ends, alphas = (
+            ray_indices[masks],
+            t_starts[masks],
+            t_ends[masks],
+            alphas[masks],
+        )
+    return ray_indices, t_starts, t_ends, alphas


 @torch.no_grad()
@@ -30,6 +74,7 @@ def ray_marching(
     proposal_nets: Optional[torch.nn.Module] = None,
     early_stop_eps: float = 1e-4,
     alpha_thre: float = 0.0,
+    proposal_nets_require_grads: bool = True,
     # rendering options
     near_plane: Optional[float] = None,
     far_plane: Optional[float] = None,
@@ -132,6 +177,9 @@ def ray_marching(
         sample_locs = rays_o[ray_indices] + t_mid * rays_d[ray_indices]

     """
+    torch.cuda.synchronize()
+    n_rays = rays_o.shape[0]
+
     if not rays_o.is_cuda:
         raise NotImplementedError("Only support cuda inputs.")
     if alpha_fn is not None and sigma_fn is not None:
@@ -163,31 +211,30 @@ def ray_marching(
     # use grid for skipping if given
     if grid is not None:
-        grid_roi_aabb = grid.roi_aabb
-        grid_binary = grid.binary
-        contraction_type = grid.contraction_type.to_cpp_version()
-    else:
-        grid_roi_aabb = torch.tensor(
-            [-1e10, -1e10, -1e10, 1e10, 1e10, 1e10],
-            dtype=torch.float32,
-            device=rays_o.device,
-        )
-        grid_binary = torch.ones(
-            [1, 1, 1], dtype=torch.bool, device=rays_o.device
-        )
-        contraction_type = ContractionType.AABB.to_cpp_version()
-
-    # marching with grid-based skipping
-    packed_info, ray_indices, t_starts, t_ends = _C.ray_marching(
-        # rays
-        rays_o.contiguous(),
-        rays_d.contiguous(),
-        t_min.contiguous(),
-        t_max.contiguous(),
-        # coontraction and grid
-        grid_roi_aabb.contiguous(),
-        grid_binary.contiguous(),
-        contraction_type,
-        # sampling
-        render_step_size,
-        cone_angle,
-    )
+        # marching with grid-based skipping
+        packed_info, ray_indices, t_starts, t_ends = _C.ray_marching_with_grid(
+            # rays
+            rays_o.contiguous(),
+            rays_d.contiguous(),
+            t_min.contiguous(),
+            t_max.contiguous(),
+            # coontraction and grid
+            grid.roi_aabb.contiguous(),
+            grid.binary.contiguous(),
+            grid.contraction_type.to_cpp_version(),
+            # sampling
+            render_step_size,
+            cone_angle,
+        )
+    else:
+        # marching
+        packed_info, ray_indices, t_starts, t_ends = _C.ray_marching(
+            # rays
+            rays_o.contiguous(),
+            rays_d.contiguous(),
+            t_min.contiguous(),
+            t_max.contiguous(),
+            # sampling
+            render_step_size,
+            cone_angle,
+        )
@@ -197,96 +244,49 @@ def ray_marching(
     if proposal_nets is not None:
         # resample with proposal nets
         for net, num_samples in zip(proposal_nets, [32]):
-            with torch.no_grad():
-                # skip invisible space
-                if sigma_fn is not None or alpha_fn is not None:
-                    # Query sigma without gradients
-                    if sigma_fn is not None:
-                        sigmas = sigma_fn(
-                            t_starts, t_ends, ray_indices.long(), net=net
-                        )
-                        assert (
-                            sigmas.shape == t_starts.shape
-                        ), "sigmas must have shape of (N, 1)! Got {}".format(
-                            sigmas.shape
-                        )
-                        alphas = 1.0 - torch.exp(-sigmas * (t_ends - t_starts))
-                    elif alpha_fn is not None:
-                        alphas = alpha_fn(
-                            t_starts, t_ends, ray_indices.long(), net=net
-                        )
-                        assert (
-                            alphas.shape == t_starts.shape
-                        ), "alphas must have shape of (N, 1)! Got {}".format(
-                            alphas.shape
-                        )
-                    # Compute visibility of the samples, and filter out invisible samples
-                    masks = render_visibility(
-                        alphas,
-                        ray_indices=ray_indices,
-                        early_stop_eps=early_stop_eps,
-                        alpha_thre=alpha_thre,
-                        n_rays=rays_o.shape[0],
-                    )
-                    ray_indices, t_starts, t_ends = (
-                        ray_indices[masks],
-                        t_starts[masks],
-                        t_ends[masks],
-                    )
-            # print(
-            #     alphas.shape,
-            #     masks.float().sum(),
-            #     alphas.min(),
-            #     alphas.max(),
-            # )
-            with torch.enable_grad():
-                sigmas = sigma_fn(t_starts, t_ends, ray_indices.long(), net=net)
-                weights = render_weight_from_density(
-                    t_starts, t_ends, sigmas, ray_indices=ray_indices
-                )
-                packed_info = pack_info(ray_indices, n_rays=rays_o.shape[0])
-                proposal_sample_list.append(
-                    (packed_info, t_starts, t_ends, weights)
-                )
-            packed_info, t_starts, t_ends = ray_resampling(
-                packed_info, t_starts, t_ends, weights, n_samples=num_samples
-            )
-            ray_indices = unpack_info(packed_info, n_samples=t_starts.shape[0])
-            with torch.no_grad():
-                # skip invisible space
-                if sigma_fn is not None or alpha_fn is not None:
-                    # Query sigma without gradients
-                    if sigma_fn is not None:
-                        sigmas = sigma_fn(t_starts, t_ends, ray_indices.long())
-                        assert (
-                            sigmas.shape == t_starts.shape
-                        ), "sigmas must have shape of (N, 1)! Got {}".format(
-                            sigmas.shape
-                        )
-                        alphas = 1.0 - torch.exp(-sigmas * (t_ends - t_starts))
-                    elif alpha_fn is not None:
-                        alphas = alpha_fn(t_starts, t_ends, ray_indices.long())
-                        assert (
-                            alphas.shape == t_starts.shape
-                        ), "alphas must have shape of (N, 1)! Got {}".format(
-                            alphas.shape
-                        )
-                    # Compute visibility of the samples, and filter out invisible samples
-                    masks = render_visibility(
-                        alphas,
-                        ray_indices=ray_indices,
-                        early_stop_eps=early_stop_eps,
-                        alpha_thre=alpha_thre,
-                        n_rays=rays_o.shape[0],
-                    )
-                    ray_indices, t_starts, t_ends = (
-                        ray_indices[masks],
-                        t_starts[masks],
-                        t_ends[masks],
-                    )
+            ray_indices, t_starts, t_ends, alphas = maybe_filter(
+                t_starts=t_starts,
+                t_ends=t_ends,
+                ray_indices=ray_indices,
+                n_rays=n_rays,
+                sigma_fn=sigma_fn,
+                alpha_fn=alpha_fn,
+                net=net,
+                early_stop_eps=early_stop_eps,
+                alpha_thre=alpha_thre,
+            )
+            packed_info = pack_info(ray_indices, n_rays=n_rays)
+
+            if proposal_nets_require_grads:
+                with torch.enable_grad():
+                    sigmas = sigma_fn(
+                        t_starts, t_ends, ray_indices.long(), net=net
+                    )
+                    weights = render_weight_from_density(
+                        t_starts, t_ends, sigmas, ray_indices=ray_indices
+                    )
+                    proposal_sample_list.append(
+                        (packed_info, t_starts, t_ends, weights)
+                    )
+            else:
+                weights = render_weight_from_alpha(
+                    alphas, ray_indices=ray_indices
+                )
+
+            packed_info, t_starts, t_ends = ray_resampling(
+                packed_info, t_starts, t_ends, weights, n_samples=num_samples
+            )
+            ray_indices = unpack_info(packed_info, n_samples=t_starts.shape[0])
+
+            ray_indices, t_starts, t_ends, _ = maybe_filter(
+                t_starts=t_starts,
+                t_ends=t_ends,
+                ray_indices=ray_indices,
+                n_rays=n_rays,
+                sigma_fn=sigma_fn,
+                alpha_fn=alpha_fn,
+                net=None,
+                early_stop_eps=early_stop_eps,
+                alpha_thre=alpha_thre,
+            )

     if proposal_nets is not None:
......
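In the rewritten loop above, the two duplicated "skip invisible space" blocks collapse into maybe_filter, which turns sigma_fn into an alpha query when needed, calls render_visibility, and drops samples that no longer contribute; when proposal_nets_require_grads is off, the proposal weights are reused from those alphas via render_weight_from_alpha instead of re-querying densities under torch.enable_grad(). A plausible dense stand-in for the visibility test, assuming a sample is kept when its alpha clears alpha_thre and the transmittance accumulated in front of it is still above early_stop_eps; the packed CUDA implementation may differ in details, and visibility_mask_reference is a hypothetical name:

import torch


def visibility_mask_reference(alphas, ray_indices, n_rays,
                              early_stop_eps=1e-4, alpha_thre=0.0):
    # alphas: (N, 1) packed samples; ray_indices: (N,) ray id per sample.
    alphas = alphas.squeeze(-1)
    keep = torch.zeros_like(alphas, dtype=torch.bool)
    for r in range(n_rays):
        sel = (ray_indices == r).nonzero(as_tuple=True)[0]
        a = alphas[sel]
        # transmittance in front of each sample along the ray
        trans = torch.cumprod(
            torch.cat([a.new_ones(1), 1.0 - a[:-1]]), dim=0
        )
        keep[sel] = (trans > early_stop_eps) & (a >= alpha_thre)
    return keep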