Commit b42d62cc authored by Ruilong Li's avatar Ruilong Li
Browse files

fix tests; trainable; checking efficiency

parent 822d5199
......@@ -28,6 +28,7 @@ def set_random_seed(seed):
torch.manual_seed(seed)
# @profile
def render_image(
# scene
radiance_field: torch.nn.Module,
......@@ -169,7 +170,7 @@ if __name__ == "__main__":
parser.add_argument(
"--test_chunk_size",
type=int,
default=8192,
default=1024,
)
parser.add_argument(
"--unbounded",
......@@ -264,8 +265,8 @@ if __name__ == "__main__":
hidden_dim=16,
max_res=64,
geo_feat_dim=0,
n_levels=2,
log2_hashmap_size=17,
n_levels=4,
log2_hashmap_size=19,
),
# NGPradianceField(
# aabb=args.aabb,
......@@ -303,6 +304,8 @@ if __name__ == "__main__":
for epoch in range(10000000):
for i in range(len(train_dataset)):
radiance_field.train()
proposal_nets.train()
data = train_dataset[i]
render_bkgd = data["color_bkgd"]
......@@ -350,6 +353,7 @@ if __name__ == "__main__":
t_ends,
weights,
) = proposal_sample_list[-1]
loss_interval = 0.0
for (
proposal_packed_info,
proposal_t_starts,
......@@ -365,7 +369,6 @@ if __name__ == "__main__":
proposal_t_starts,
proposal_t_ends,
).detach()
torch.cuda.synchronize()
loss_interval = (
torch.clamp(proposal_weights_gt - proposal_weights, min=0)
......@@ -392,6 +395,7 @@ if __name__ == "__main__":
if step >= 0 and step % 1000 == 0 and step > 0:
# evaluation
radiance_field.eval()
proposal_nets.eval()
psnrs = []
with torch.no_grad():
......
......@@ -8,11 +8,12 @@ from .cdf import ray_resampling
from .contraction import ContractionType
from .grid import Grid
from .intersection import ray_aabb_intersect
from .pack import unpack_info
from .pack import pack_info, unpack_info
from .vol_rendering import render_visibility, render_weight_from_density
@torch.no_grad()
# @profile
def ray_marching(
# rays
rays_o: torch.Tensor,
......@@ -192,15 +193,60 @@ def ray_marching(
cone_angle,
)
proposal_sample_list = []
if proposal_nets is not None:
proposal_sample_list = []
# resample with proposal nets
for net, num_samples in zip(proposal_nets, [32]):
for net, num_samples in zip(proposal_nets, [64]):
with torch.no_grad():
# skip invisible space
if sigma_fn is not None or alpha_fn is not None:
# Query sigma without gradients
if sigma_fn is not None:
sigmas = sigma_fn(
t_starts, t_ends, ray_indices.long(), net=net
)
assert (
sigmas.shape == t_starts.shape
), "sigmas must have shape of (N, 1)! Got {}".format(
sigmas.shape
)
alphas = 1.0 - torch.exp(-sigmas * (t_ends - t_starts))
elif alpha_fn is not None:
alphas = alpha_fn(
t_starts, t_ends, ray_indices.long(), net=net
)
assert (
alphas.shape == t_starts.shape
), "alphas must have shape of (N, 1)! Got {}".format(
alphas.shape
)
# Compute visibility of the samples, and filter out invisible samples
masks = render_visibility(
alphas,
ray_indices=ray_indices,
early_stop_eps=early_stop_eps,
alpha_thre=min(alphas.mean().item(), 1e-1),
n_rays=rays_o.shape[0],
)
ray_indices, t_starts, t_ends = (
ray_indices[masks],
t_starts[masks],
t_ends[masks],
)
# print(
# alphas.shape,
# masks.float().sum(),
# alphas.min(),
# alphas.max(),
# )
with torch.enable_grad():
sigmas = sigma_fn(t_starts, t_ends, ray_indices.long(), net=net)
weights = render_weight_from_density(
t_starts, t_ends, sigmas, ray_indices=ray_indices
)
packed_info = pack_info(ray_indices, n_rays=rays_o.shape[0])
proposal_sample_list.append(
(packed_info, t_starts, t_ends, weights)
)
......@@ -209,35 +255,6 @@ def ray_marching(
)
ray_indices = unpack_info(packed_info, n_samples=t_starts.shape[0])
# skip invisible space
if sigma_fn is not None or alpha_fn is not None:
# Query sigma without gradients
if sigma_fn is not None:
sigmas = sigma_fn(t_starts, t_ends, ray_indices.long())
assert (
sigmas.shape == t_starts.shape
), "sigmas must have shape of (N, 1)! Got {}".format(sigmas.shape)
alphas = 1.0 - torch.exp(-sigmas * (t_ends - t_starts))
elif alpha_fn is not None:
alphas = alpha_fn(t_starts, t_ends, ray_indices.long())
assert (
alphas.shape == t_starts.shape
), "alphas must have shape of (N, 1)! Got {}".format(alphas.shape)
# Compute visibility of the samples, and filter out invisible samples
masks = render_visibility(
alphas,
ray_indices=ray_indices,
early_stop_eps=early_stop_eps,
alpha_thre=alpha_thre,
n_rays=rays_o.shape[0],
)
ray_indices, t_starts, t_ends = (
ray_indices[masks],
t_starts[masks],
t_ends[masks],
)
if proposal_nets is not None:
return ray_indices, t_starts, t_ends, proposal_sample_list
else:
......
......@@ -50,7 +50,7 @@ def test_pdf_query():
render_step_size=0.2,
)
packed_info = pack_info(ray_indices, n_rays)
weights = torch.rand((t_starts.shape[0],), device=device)
weights = torch.rand((t_starts.shape[0], 1), device=device)
packed_info_new = packed_info
t_starts_new = t_starts - 0.3
......
......@@ -105,8 +105,8 @@ def test_resampling():
t_starts = t[:, :-1][masks].unsqueeze(-1)
t_ends = t[:, 1:][masks].unsqueeze(-1)
w_logits = w_logits[masks]
w = w[masks]
w_logits = w_logits[masks].unsqueeze(-1)
w = w[masks].unsqueeze(-1)
num_steps = masks.long().sum(dim=-1)
cum_steps = torch.cumsum(num_steps, dim=0)
packed_info = torch.stack([cum_steps - num_steps, num_steps], dim=-1).int()
......@@ -143,7 +143,7 @@ def test_pdf_query():
)
packed_info = pack_info(ray_indices, rays_o.shape[0])
weights = torch.rand((t_starts.shape[0],), device=device)
weights = torch.rand((t_starts.shape[0], 1), device=device)
weights_new = ray_pdf_query(
packed_info,
t_starts,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment