"vscode:/vscode.git/clone" did not exist on "661f8f1637938cda38c9b5dbacc4d6fffa3d8a03"
Commit 6f7f9fb0 authored by Ruilong Li's avatar Ruilong Li
Browse files

voxel as proposal: 7k; 292s, 54k rays, loss 0.00076

parent 477ff3df
...@@ -111,6 +111,7 @@ class NGPradianceField(torch.nn.Module): ...@@ -111,6 +111,7 @@ class NGPradianceField(torch.nn.Module):
}, },
) )
if hidden_dim > 0:
self.mlp_base = tcnn.NetworkWithInputEncoding( self.mlp_base = tcnn.NetworkWithInputEncoding(
n_input_dims=num_dim, n_input_dims=num_dim,
n_output_dims=1 + self.geo_feat_dim, n_output_dims=1 + self.geo_feat_dim,
...@@ -130,6 +131,19 @@ class NGPradianceField(torch.nn.Module): ...@@ -130,6 +131,19 @@ class NGPradianceField(torch.nn.Module):
"n_hidden_layers": 1, "n_hidden_layers": 1,
}, },
) )
else:
self.mlp_base = tcnn.Encoding(
n_input_dims=num_dim,
encoding_config={
"otype": "HashGrid",
"n_levels": 1,
"n_features_per_level": 1,
"log2_hashmap_size": 21,
"base_resolution": 128,
"per_level_scale": 1.0,
},
)
if self.geo_feat_dim > 0: if self.geo_feat_dim > 0:
self.mlp_head = tcnn.Network( self.mlp_head = tcnn.Network(
n_input_dims=( n_input_dims=(
......
...@@ -262,16 +262,22 @@ if __name__ == "__main__": ...@@ -262,16 +262,22 @@ if __name__ == "__main__":
NGPradianceField( NGPradianceField(
aabb=args.aabb, aabb=args.aabb,
use_viewdirs=False, use_viewdirs=False,
hidden_dim=16, hidden_dim=0,
max_res=64,
geo_feat_dim=0, geo_feat_dim=0,
n_levels=4,
log2_hashmap_size=19,
), ),
# NGPradianceField( # NGPradianceField(
# aabb=args.aabb, # aabb=args.aabb,
# use_viewdirs=False, # use_viewdirs=False,
# hidden_dim=16, # hidden_dim=16,
# max_res=64,
# geo_feat_dim=0,
# n_levels=4,
# log2_hashmap_size=19,
# ),
# NGPradianceField(
# aabb=args.aabb,
# use_viewdirs=False,
# hidden_dim=16,
# max_res=256, # max_res=256,
# geo_feat_dim=0, # geo_feat_dim=0,
# n_levels=5, # n_levels=5,
...@@ -374,7 +380,7 @@ if __name__ == "__main__": ...@@ -374,7 +380,7 @@ if __name__ == "__main__":
torch.clamp(proposal_weights_gt - proposal_weights, min=0) torch.clamp(proposal_weights_gt - proposal_weights, min=0)
) ** 2 / (proposal_weights + torch.finfo(torch.float32).eps) ) ** 2 / (proposal_weights + torch.finfo(torch.float32).eps)
loss_interval = loss_interval.mean() loss_interval = loss_interval.mean()
loss += loss_interval * 0.1 loss += loss_interval * 1.0
optimizer.zero_grad() optimizer.zero_grad()
# do not unscale it because we are using Adam. # do not unscale it because we are using Adam.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment