Commit 873d050a authored by Ruilong Li

update performance with fixed random seed

parent 83ae4fd4
@@ -16,5 +16,5 @@ Here the speed refers to the `iterations per second`.
| Model | Split | PSNR | Train Speed | Test Speed | GPU |
| - | - | - | - | - | - |
| instant-ngp (paper) | trainval? | 36.39 | - | - | 3090 |
| torch-ngp (`-O`) | train (30K steps) | 34.15 | 97 | 7.8 | V100 |
| ours | train (30K steps) | 34.26 | 96 | ? | TITAN RTX |
\ No newline at end of file
| torch-ngp (`-O`) | train (30K steps) | 34.15 | 97 (310 sec) | 7.8 | V100 |
| ours | train (30K steps) | 34.42 | 94 (320 sec) | 6.1 | TITAN RTX |
\ No newline at end of file
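A quick consistency check on the updated row, assuming the values in parentheses are total wall-clock training time for the 30K steps: 30,000 / 310 s ≈ 97 it/s for torch-ngp and 30,000 / 320 s ≈ 94 it/s for ours, which matches the reported train speeds.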
@@ -32,7 +32,7 @@ def render_image(radiance_field, rays, render_bkgd):
    else:
        num_rays, _ = rays_shape
    results = []
    chunk = torch.iinfo(torch.int32).max if radiance_field.training else 8192
    chunk = torch.iinfo(torch.int32).max if radiance_field.training else 81920
    for i in range(0, num_rays, chunk):
        chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays)
        chunk_color, chunk_depth, chunk_weight, alive_ray_mask = volumetric_rendering(
@@ -56,6 +56,7 @@ def render_image(radiance_field, rays, render_bkgd):
if __name__ == "__main__":
    torch.manual_seed(42)
    device = "cuda:0"
@@ -80,7 +81,7 @@ if __name__ == "__main__":
    )
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        num_workers=1,
        num_workers=10,
        batch_size=1,
        collate_fn=getattr(train_dataset.__class__, "collate_fn"),
    )
@@ -112,6 +113,7 @@ if __name__ == "__main__":
            occupancy values with shape (N, 1).
        """
        density_after_activation = radiance_field.query_density(x)
        # occupancy = 1.0 - torch.exp(-density_after_activation * render_step_size)
        occupancy = density_after_activation * render_step_size
        return occupancy
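The commented-out line above is the exact per-step opacity, 1 - exp(-density * render_step_size); the new line keeps only the first-order term of that expression, which is a close approximation whenever density * render_step_size is small. A minimal sketch of the difference (the densities and step size here are hypothetical, not taken from the repo):

import torch

sigma = torch.tensor([0.1, 1.0, 10.0])               # hypothetical densities after activation
render_step_size = 5e-3                               # hypothetical step size
exact = 1.0 - torch.exp(-sigma * render_step_size)    # commented-out formula
approx = sigma * render_step_size                     # formula used in this commit
print((exact - approx).abs().max())                   # ~1.2e-3 at the largest product (0.05)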
@@ -174,5 +176,5 @@ if __name__ == "__main__":
            psnr_avg = sum(psnrs) / len(psnrs)
            print(f"evaluation: {psnr_avg=}")
            # elapsed_time=312.5340702533722s | step=30000 | loss= 0.00025
            # evaluation: psnr_avg=34.261171398162844 (4.12 it/s)
            # elapsed_time=320.04s | step=30000 | loss= 0.00022
            # evaluation: psnr_avg=34.41712421417236 (6.13 it/s)
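For context, the per-image PSNR values averaged above are conventionally derived from the MSE between rendered and ground-truth images. The metric code itself is outside this excerpt, so the helper below is only an illustrative sketch of the standard formula:

import torch
import torch.nn.functional as F

def psnr(pred: torch.Tensor, gt: torch.Tensor) -> torch.Tensor:
    # PSNR = -10 * log10(MSE), for images with values in [0, 1]
    return -10.0 * torch.log10(F.mse_loss(pred, gt))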
from typing import List, Tuple

import torch
import torch.nn.functional as F


def query_grid(x: torch.Tensor, aabb: torch.Tensor, grid: torch.Tensor) -> torch.Tensor:
    """Query values in the grid field given the coordinates.
    Args:
        x: 2D / 3D coordinates, with shape of [..., 2 or 3]
        aabb: 2D / 3D bounding box of the field, with shape of [4 or 6]
        grid: Grid with shape [res_x, res_y, res_z, D] or [res_x, res_y, D]
    Returns:
        values with shape [..., D]
    """
    output_shape = list(x.shape[:-1]) + [grid.shape[-1]]
    if x.shape[-1] == 2 and aabb.shape == (4,) and grid.ndim == 3:
        # 2D case
        grid = grid.permute(2, 1, 0).unsqueeze(0)  # [1, D, res_y, res_x]
        x = (x.view(1, -1, 1, 2) - aabb[0:2]) / (aabb[2:4] - aabb[0:2])
    elif x.shape[-1] == 3 and aabb.shape == (6,) and grid.ndim == 4:
        # 3D case
        grid = grid.permute(3, 2, 1, 0).unsqueeze(0)  # [1, D, res_z, res_y, res_x]
        x = (x.view(1, -1, 1, 1, 3) - aabb[0:3]) / (aabb[3:6] - aabb[0:3])
    else:
        raise ValueError(
            "The shapes of the inputs do not match to either 2D case or 3D case! "
            f"Got x: {x.shape}; aabb: {aabb.shape}; grid: {grid.shape}."
        )
    v = F.grid_sample(
        grid,
        x * 2.0 - 1.0,
        align_corners=True,
        padding_mode="zeros",
    )
    # grid_sample returns channels first ([1, D, ...]); move the channel dim last
    # before flattening so that each output row corresponds to one query point.
    v = v.movedim(1, -1).reshape(output_shape)
    return v


def meshgrid(resolution: List[int]):
    if len(resolution) == 2:
        return meshgrid2d(resolution)
    elif len(resolution) == 3:
        return meshgrid3d(resolution)
    else:
        raise ValueError(resolution)


def meshgrid2d(res: Tuple[int, int], device: torch.device = "cpu"):
    """Create 2D grid coordinates.
    Args:
        res (Tuple[int, int]): resolutions for {x, y} dimensions.
    Returns:
        torch.long with shape (res[0], res[1], 2): dense 2D grid coordinates.
    """
    return (
        torch.stack(
            torch.meshgrid(
                [
                    torch.arange(res[0]),
                    torch.arange(res[1]),
                ],
                indexing="ij",
            ),
            dim=-1,
        )
        .long()
        .to(device)
    )


def meshgrid3d(res: Tuple[int, int, int], device: torch.device = "cpu"):
    """Create 3D grid coordinates.
    Args:
        res (Tuple[int, int, int]): resolutions for {x, y, z} dimensions.
    Returns:
        torch.long with shape (res[0], res[1], res[2], 3): dense 3D grid coordinates.
    """
    return (
        torch.stack(
            torch.meshgrid(
                [
                    torch.arange(res[0]),
                    torch.arange(res[1]),
                    torch.arange(res[2]),
                ],
                indexing="ij",
            ),
            dim=-1,
        )
        .long()
        .to(device)
    )
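A minimal usage sketch for query_grid above, assuming the functions in this file are in scope; the bounding box, resolution, and feature dimension are hypothetical, not taken from the commit:

aabb = torch.tensor([0.0, 0.0, 0.0, 1.0, 1.0, 1.0])   # unit cube: [min_xyz, max_xyz]
grid = torch.rand(4, 4, 4, 2)                          # [res_x, res_y, res_z, D=2]
x = torch.rand(5, 3)                                   # 5 query points inside the box
values = query_grid(x, aabb, grid)                     # trilinearly interpolated, shape (5, 2)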
@@ -3,7 +3,31 @@ from typing import Callable, List, Tuple, Union

import torch
from torch import nn

from ._grid import meshgrid


def meshgrid3d(res: Tuple[int, int, int], device: torch.device = "cpu"):
    """Create 3D grid coordinates.
    Args:
        res (Tuple[int, int, int]): resolutions for {x, y, z} dimensions.
    Returns:
        torch.long with shape (res[0], res[1], res[2], 3): dense 3D grid coordinates.
    """
    return (
        torch.stack(
            torch.meshgrid(
                [
                    torch.arange(res[0]),
                    torch.arange(res[1]),
                    torch.arange(res[2]),
                ],
                indexing="ij",
            ),
            dim=-1,
        )
        .long()
        .to(device)
    )

class OccupancyField(nn.Module):
@@ -62,7 +86,7 @@ class OccupancyField(nn.Module):
        self.register_buffer("occ_grid_mean", occ_grid_mean)

        # Grid coords & indices
        grid_coords = meshgrid(self.resolution).reshape(self.num_cells, self.num_dim)
        grid_coords = meshgrid3d(self.resolution).reshape(self.num_cells, self.num_dim)
        self.register_buffer("grid_coords", grid_coords)
        grid_indices = torch.arange(self.num_cells)
        self.register_buffer("grid_indices", grid_indices)