Commit e6647a00 authored by Ruilong Li's avatar Ruilong Li
Browse files

a big cleanup: data fix seed; ngp cubic box

parent 6ab97aeb
......@@ -258,6 +258,8 @@ class SubjectLoader(torch.utils.data.Dataset):
)
self.K = torch.tensor(self.K).to(torch.float32).to(device)
self.height, self.width = self.images.shape[1:3]
self.g = torch.Generator(device=device)
self.g.manual_seed(42)
def __len__(self):
return len(self.images)
......@@ -274,7 +276,7 @@ class SubjectLoader(torch.utils.data.Dataset):
if self.training:
if self.color_bkgd_aug == "random":
color_bkgd = torch.rand(3, device=self.images.device)
color_bkgd = torch.rand(3, device=self.images.device, generator=self.g)
elif self.color_bkgd_aug == "white":
color_bkgd = torch.ones(3, device=self.images.device)
elif self.color_bkgd_aug == "black":
......@@ -304,14 +306,15 @@ class SubjectLoader(torch.utils.data.Dataset):
len(self.images),
size=(num_rays,),
device=self.images.device,
generator=self.g,
)
else:
image_id = [index] * num_rays
x = torch.randint(
0, self.width, size=(num_rays,), device=self.images.device
0, self.width, size=(num_rays,), device=self.images.device, generator=self.g
)
y = torch.randint(
0, self.height, size=(num_rays,), device=self.images.device
0, self.height, size=(num_rays,), device=self.images.device, generator=self.g
)
else:
image_id = [index]
......
......@@ -124,6 +124,8 @@ class SubjectLoader(torch.utils.data.Dataset):
self.camtoworlds = self.camtoworlds.to(device)
self.K = self.K.to(device)
assert self.images.shape[1:3] == (self.HEIGHT, self.WIDTH)
self.g = torch.Generator(device=device)
self.g.manual_seed(42)
def __len__(self):
return len(self.images)
......@@ -141,7 +143,7 @@ class SubjectLoader(torch.utils.data.Dataset):
if self.training:
if self.color_bkgd_aug == "random":
color_bkgd = torch.rand(3, device=self.images.device)
color_bkgd = torch.rand(3, device=self.images.device, generator=self.g)
elif self.color_bkgd_aug == "white":
color_bkgd = torch.ones(3, device=self.images.device)
elif self.color_bkgd_aug == "black":
......@@ -172,14 +174,15 @@ class SubjectLoader(torch.utils.data.Dataset):
len(self.images),
size=(num_rays,),
device=self.images.device,
generator=self.g,
)
else:
image_id = [index] * num_rays
x = torch.randint(
0, self.WIDTH, size=(num_rays,), device=self.images.device
0, self.WIDTH, size=(num_rays,), device=self.images.device, generator=self.g
)
y = torch.randint(
0, self.HEIGHT, size=(num_rays,), device=self.images.device
0, self.HEIGHT, size=(num_rays,), device=self.images.device, generator=self.g
)
else:
image_id = [index]
......
......@@ -85,6 +85,13 @@ class NGPRadianceField(torch.nn.Module):
super().__init__()
if not isinstance(aabb, torch.Tensor):
aabb = torch.tensor(aabb, dtype=torch.float32)
# Turns out rectangle aabb will leads to uneven collision so bad performance.
# We enforce a cube aabb here.
center = (aabb[..., :num_dim] + aabb[..., num_dim:]) / 2.0
size = (aabb[..., num_dim:] - aabb[..., :num_dim]).max()
aabb = torch.cat([center - size / 2.0, center + size / 2.0], dim=-1)
self.register_buffer("aabb", aabb)
self.num_dim = num_dim
self.use_viewdirs = use_viewdirs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment