Commit 6f7f9fb0 authored by Ruilong Li's avatar Ruilong Li
Browse files

voxel as proposal: 7k; 292s, 54k rays, loss 0.00076

parent 477ff3df
......@@ -111,25 +111,39 @@ class NGPradianceField(torch.nn.Module):
},
)
self.mlp_base = tcnn.NetworkWithInputEncoding(
n_input_dims=num_dim,
n_output_dims=1 + self.geo_feat_dim,
encoding_config={
"otype": "HashGrid",
"n_levels": n_levels,
"n_features_per_level": 2,
"log2_hashmap_size": log2_hashmap_size,
"base_resolution": base_res,
"per_level_scale": per_level_scale,
},
network_config={
"otype": "FullyFusedMLP",
"activation": "ReLU",
"output_activation": "None",
"n_neurons": hidden_dim,
"n_hidden_layers": 1,
},
)
if hidden_dim > 0:
self.mlp_base = tcnn.NetworkWithInputEncoding(
n_input_dims=num_dim,
n_output_dims=1 + self.geo_feat_dim,
encoding_config={
"otype": "HashGrid",
"n_levels": n_levels,
"n_features_per_level": 2,
"log2_hashmap_size": log2_hashmap_size,
"base_resolution": base_res,
"per_level_scale": per_level_scale,
},
network_config={
"otype": "FullyFusedMLP",
"activation": "ReLU",
"output_activation": "None",
"n_neurons": hidden_dim,
"n_hidden_layers": 1,
},
)
else:
self.mlp_base = tcnn.Encoding(
n_input_dims=num_dim,
encoding_config={
"otype": "HashGrid",
"n_levels": 1,
"n_features_per_level": 1,
"log2_hashmap_size": 21,
"base_resolution": 128,
"per_level_scale": 1.0,
},
)
if self.geo_feat_dim > 0:
self.mlp_head = tcnn.Network(
n_input_dims=(
......
......@@ -262,16 +262,22 @@ if __name__ == "__main__":
NGPradianceField(
aabb=args.aabb,
use_viewdirs=False,
hidden_dim=16,
max_res=64,
hidden_dim=0,
geo_feat_dim=0,
n_levels=4,
log2_hashmap_size=19,
),
# NGPradianceField(
# aabb=args.aabb,
# use_viewdirs=False,
# hidden_dim=16,
# max_res=64,
# geo_feat_dim=0,
# n_levels=4,
# log2_hashmap_size=19,
# ),
# NGPradianceField(
# aabb=args.aabb,
# use_viewdirs=False,
# hidden_dim=16,
# max_res=256,
# geo_feat_dim=0,
# n_levels=5,
......@@ -374,7 +380,7 @@ if __name__ == "__main__":
torch.clamp(proposal_weights_gt - proposal_weights, min=0)
) ** 2 / (proposal_weights + torch.finfo(torch.float32).eps)
loss_interval = loss_interval.mean()
loss += loss_interval * 0.1
loss += loss_interval * 1.0
optimizer.zero_grad()
# do not unscale it because we are using Adam.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment