Commit f9a6ae6a authored by Ruilong Li's avatar Ruilong Li
Browse files

ngp args

parent 346dd51a
...@@ -73,6 +73,9 @@ class NGPradianceField(torch.nn.Module): ...@@ -73,6 +73,9 @@ class NGPradianceField(torch.nn.Module):
use_viewdirs: bool = True, use_viewdirs: bool = True,
density_activation: Callable = lambda x: trunc_exp(x - 1), density_activation: Callable = lambda x: trunc_exp(x - 1),
unbounded: bool = False, unbounded: bool = False,
geo_feat_dim: int = 15,
n_levels: int = 16,
log2_hashmap_size: int = 19,
) -> None: ) -> None:
super().__init__() super().__init__()
if not isinstance(aabb, torch.Tensor): if not isinstance(aabb, torch.Tensor):
...@@ -83,7 +86,7 @@ class NGPradianceField(torch.nn.Module): ...@@ -83,7 +86,7 @@ class NGPradianceField(torch.nn.Module):
self.density_activation = density_activation self.density_activation = density_activation
self.unbounded = unbounded self.unbounded = unbounded
self.geo_feat_dim = 15 self.geo_feat_dim = geo_feat_dim
per_level_scale = 1.4472692012786865 per_level_scale = 1.4472692012786865
if self.use_viewdirs: if self.use_viewdirs:
...@@ -107,9 +110,9 @@ class NGPradianceField(torch.nn.Module): ...@@ -107,9 +110,9 @@ class NGPradianceField(torch.nn.Module):
n_output_dims=1 + self.geo_feat_dim, n_output_dims=1 + self.geo_feat_dim,
encoding_config={ encoding_config={
"otype": "HashGrid", "otype": "HashGrid",
"n_levels": 16, "n_levels": n_levels,
"n_features_per_level": 2, "n_features_per_level": 2,
"log2_hashmap_size": 19, "log2_hashmap_size": log2_hashmap_size,
"base_resolution": 16, "base_resolution": 16,
"per_level_scale": per_level_scale, "per_level_scale": per_level_scale,
}, },
...@@ -121,25 +124,25 @@ class NGPradianceField(torch.nn.Module): ...@@ -121,25 +124,25 @@ class NGPradianceField(torch.nn.Module):
"n_hidden_layers": 1, "n_hidden_layers": 1,
}, },
) )
if self.geo_feat_dim > 0:
self.mlp_head = tcnn.Network( self.mlp_head = tcnn.Network(
n_input_dims=( n_input_dims=(
( (
self.direction_encoding.n_output_dims self.direction_encoding.n_output_dims
if self.use_viewdirs if self.use_viewdirs
else 0 else 0
) )
+ self.geo_feat_dim + self.geo_feat_dim
), ),
n_output_dims=3, n_output_dims=3,
network_config={ network_config={
"otype": "FullyFusedMLP", "otype": "FullyFusedMLP",
"activation": "ReLU", "activation": "ReLU",
"output_activation": "Sigmoid", "output_activation": "Sigmoid",
"n_neurons": 64, "n_neurons": 64,
"n_hidden_layers": 2, "n_hidden_layers": 2,
}, },
) )
def query_density(self, x, return_feat: bool = False): def query_density(self, x, return_feat: bool = False):
if self.unbounded: if self.unbounded:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment