"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "ffe7b93b60e037b2b1ec056b88cba7c14fc3f9e3"
Commit 8df25048 authored by Ruilong Li's avatar Ruilong Li
Browse files

ndr trial

parent 3c876cf1
...@@ -281,3 +281,115 @@ class TNeRFRadianceField(nn.Module): ...@@ -281,3 +281,115 @@ class TNeRFRadianceField(nn.Module):
torch.cat([self.posi_encoder(x), self.time_encoder(t)], dim=-1) torch.cat([self.posi_encoder(x), self.time_encoder(t)], dim=-1)
) )
return self.nerf(x, condition=condition) return self.nerf(x, condition=condition)
class NDRTNeRFRadianceField(nn.Module):
"""Invertble NN from https://arxiv.org/pdf/2206.15258.pdf"""
def __init__(self) -> None:
super().__init__()
self.time_encoder = SinusoidalEncoder(1, 0, 4, True)
self.warp_layers_1 = nn.ModuleList()
self.time_layers_1 = nn.ModuleList()
self.warp_layers_2 = nn.ModuleList()
self.time_layers_2 = nn.ModuleList()
self.posi_encoder_1 = SinusoidalEncoder(2, 0, 4, True)
self.posi_encoder_2 = SinusoidalEncoder(1, 0, 4, True)
for _ in range(3):
self.warp_layers_1.append(
MLP(
input_dim=self.posi_encoder_1.latent_dim + 64,
output_dim=1,
net_depth=2,
net_width=128,
skip_layer=None,
output_init=functools.partial(
torch.nn.init.uniform_, b=1e-4
),
)
)
self.warp_layers_2.append(
MLP(
input_dim=self.posi_encoder_2.latent_dim + 64,
output_dim=1 + 2,
net_depth=1,
net_width=128,
skip_layer=None,
output_init=functools.partial(
torch.nn.init.uniform_, b=1e-4
),
)
)
self.time_layers_1.append(
DenseLayer(
input_dim=self.time_encoder.latent_dim,
output_dim=64,
)
)
self.time_layers_2.append(
DenseLayer(
input_dim=self.time_encoder.latent_dim,
output_dim=64,
)
)
self.nerf = VanillaNeRFRadianceField()
def _warp(self, x, t_enc, i_layer):
uv, w = x[:, :2], x[:, 2:]
dw = self.warp_layers_1[i_layer](
torch.cat(
[self.posi_encoder_1(uv), self.time_layers_1[i_layer](t_enc)],
dim=-1,
)
)
w = w + dw
rt = self.warp_layers_2[i_layer](
torch.cat(
[self.posi_encoder_2(w), self.time_layers_2[i_layer](t_enc)],
dim=-1,
)
)
r = self._euler2rot_2dinv(rt[:, :1])
t = rt[:, 1:]
uv = torch.bmm(r, (uv - t)[..., None]).squeeze(-1)
return torch.cat([uv, w], dim=-1)
def warp(self, x, t):
t_enc = self.time_encoder(t)
x = self._warp(x, t_enc, 0)
x = x[..., [1, 2, 0]]
x = self._warp(x, t_enc, 1)
x = x[..., [2, 0, 1]]
x = self._warp(x, t_enc, 2)
return x
def query_opacity(self, x, timestamps, step_size):
idxs = torch.randint(0, len(timestamps), (x.shape[0],), device=x.device)
t = timestamps[idxs]
density = self.query_density(x, t)
# if the density is small enough those two are the same.
# opacity = 1.0 - torch.exp(-density * step_size)
opacity = density * step_size
return opacity
def query_density(self, x, t):
x = self.warp(x, t)
return self.nerf.query_density(x)
def forward(self, x, t, condition=None):
x = self.warp(x, t)
return self.nerf(x, condition=condition)
def _euler2rot_2dinv(self, euler_angle):
# (B, 1) -> (B, 2, 2)
theta = euler_angle.reshape(-1, 1, 1)
rot = torch.cat(
(
torch.cat((theta.cos(), -theta.sin()), 1),
torch.cat((theta.sin(), theta.cos()), 1),
),
2,
)
return rot
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment