Commit 0c7b5320 authored by Ruilong Li's avatar Ruilong Li
Browse files

update truncexp

parent 09669d2a
...@@ -23,6 +23,7 @@ Tested with the default settings on the Lego test set. ...@@ -23,6 +23,7 @@ Tested with the default settings on the Lego test set.
| - | - | - | - | - | - | - | | - | - | - | - | - | - | - |
| instant-ngp (paper) | trainval? | 36.39 | - | - | 3090 | | instant-ngp (paper) | trainval? | 36.39 | - | - | 3090 |
| instant-ngp (code) | train (35k steps) | 36.08 | 308 sec | 55.32 fps | TITAN RTX | 1734MB | | instant-ngp (code) | train (35k steps) | 36.08 | 308 sec | 55.32 fps | TITAN RTX | 1734MB |
| instant-ngp (code) w/o rng bkgd| train (35k steps) | 34.17 | - | - | - | - |
| torch-ngp (`-O`) | train (30K steps) | 34.15 | 310 sec | 7.8 fps | V100 | | torch-ngp (`-O`) | train (30K steps) | 34.15 | 310 sec | 7.8 fps | V100 |
| ours | trainval (35K steps) | 36.22 | 378 sec | 12.08 fps | TITAN RTX | | ours | trainval (35K steps) | 36.22 | 378 sec | 12.08 fps | TITAN RTX |
......
...@@ -6,8 +6,9 @@ from torch.cuda.amp import custom_bwd, custom_fwd ...@@ -6,8 +6,9 @@ from torch.cuda.amp import custom_bwd, custom_fwd
try: try:
import tinycudann as tcnn import tinycudann as tcnn
except ImportError: except ImportError as e:
print( print(
f"Error: {e}! "
"Please install tinycudann by: " "Please install tinycudann by: "
"pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch" "pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch"
) )
...@@ -16,32 +17,35 @@ except ImportError: ...@@ -16,32 +17,35 @@ except ImportError:
from .base import BaseRadianceField from .base import BaseRadianceField
class NGPradianceField(BaseRadianceField): class _TruncExp(Function): # pylint: disable=abstract-method
"""Instance-NGP radiance Field""" # Implementation from torch-ngp:
# https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, x): # pylint: disable=arguments-differ
ctx.save_for_backward(x)
return torch.exp(x)
@staticmethod
@custom_bwd
def backward(ctx, g): # pylint: disable=arguments-differ
x = ctx.saved_tensors[0]
return g * torch.exp(torch.clamp(x, max=15))
class _TruncExp(Function):  # pylint: disable=abstract-method
    """``exp`` activation with a gradient clamped for numerical stability.

    Implementation from torch-ngp:
    https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, x):  # pylint: disable=arguments-differ
        # Keep the input around so backward() can clamp it before exponentiating.
        ctx.save_for_backward(x)
        return torch.exp(x)

    @staticmethod
    @custom_bwd
    def backward(ctx, g):  # pylint: disable=arguments-differ
        (x,) = ctx.saved_tensors
        # Clamp to [-15, 15] so the gradient cannot overflow for large inputs.
        return g * torch.exp(torch.clamp(x, min=-15, max=15))


trunc_exp = _TruncExp.apply
class NGPradianceField(BaseRadianceField):
"""Instance-NGP radiance Field"""
def __init__( def __init__(
self, self,
aabb: Union[torch.Tensor, List[float]], aabb: Union[torch.Tensor, List[float]],
num_dim: int = 3, num_dim: int = 3,
use_viewdirs: bool = True, use_viewdirs: bool = True,
density_activation: Callable = trunc_exp, density_activation: Callable = lambda x: trunc_exp(x - 1),
# density_activation: Callable = lambda x: torch.nn.functional.softplus(x - 1),
) -> None: ) -> None:
super().__init__() super().__init__()
if not isinstance(aabb, torch.Tensor): if not isinstance(aabb, torch.Tensor):
......
import math import math
import time import time
import imageio
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -10,28 +11,7 @@ from radiance_fields.ngp import NGPradianceField ...@@ -10,28 +11,7 @@ from radiance_fields.ngp import NGPradianceField
from nerfacc import OccupancyField, volumetric_rendering from nerfacc import OccupancyField, volumetric_rendering
TARGET_SAMPLE_BATCH_SIZE = 1 << 16 TARGET_SAMPLE_BATCH_SIZE = 1 << 18
# import tqdm
# device = "cuda:0"
# radiance_field = NGPradianceField(aabb=[0, 0, 0, 1, 1, 1]).to(device)
# positions = torch.rand((TARGET_SAMPLE_BATCH_SIZE, 3), device=device)
# directions = torch.rand(positions.shape, device=device)
# optimizer = torch.optim.Adam(
# radiance_field.parameters(),
# lr=1e-10,
# # betas=(0.9, 0.99),
# eps=1e-15,
# # weight_decay=1e-6,
# )
# for _ in tqdm.tqdm(range(1000)):
# rgbs, sigmas = radiance_field(positions, directions)
# loss = rgbs.mean()
# optimizer.zero_grad()
# loss.backward()
# optimizer.step()
# exit()
def render_image(radiance_field, rays, render_bkgd, render_step_size): def render_image(radiance_field, rays, render_bkgd, render_step_size):
...@@ -91,8 +71,9 @@ if __name__ == "__main__": ...@@ -91,8 +71,9 @@ if __name__ == "__main__":
train_dataset = SubjectLoader( train_dataset = SubjectLoader(
subject_id=scene, subject_id=scene,
root_fp="/home/ruilongli/data/nerf_synthetic/", root_fp="/home/ruilongli/data/nerf_synthetic/",
split="trainval", split="train",
num_rays=1024, num_rays=1024,
# color_bkgd_aug="random",
) )
train_dataset.images = train_dataset.images.to(device) train_dataset.images = train_dataset.images.to(device)
...@@ -139,12 +120,12 @@ if __name__ == "__main__": ...@@ -139,12 +120,12 @@ if __name__ == "__main__":
optimizer = torch.optim.Adam( optimizer = torch.optim.Adam(
radiance_field.parameters(), radiance_field.parameters(),
lr=1e-2, lr=1e-2,
# betas=(0.9, 0.99), betas=(0.9, 0.99),
eps=1e-15, eps=1e-15,
# weight_decay=1e-6, weight_decay=1e-6,
) )
scheduler = torch.optim.lr_scheduler.MultiStepLR( scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer, milestones=[20000, 30000], gamma=0.1 optimizer, milestones=[10000, 15000, 18000], gamma=0.33
) )
# setup occupancy field with eval function # setup occupancy field with eval function
...@@ -172,8 +153,11 @@ if __name__ == "__main__": ...@@ -172,8 +153,11 @@ if __name__ == "__main__":
data_time = 0 data_time = 0
tic_data = time.time() tic_data = time.time()
# Scaling up the gradients for Adam
grad_scaler = torch.cuda.amp.GradScaler(2**10)
for epoch in range(10000000): for epoch in range(10000000):
for i in range(len(train_dataset)): for i in range(len(train_dataset)):
radiance_field.train()
data = train_dataset[i] data = train_dataset[i]
data_time += time.time() - tic_data data_time += time.time() - tic_data
...@@ -198,27 +182,29 @@ if __name__ == "__main__": ...@@ -198,27 +182,29 @@ if __name__ == "__main__":
alive_ray_mask = acc.squeeze(-1) > 0 alive_ray_mask = acc.squeeze(-1) > 0
# compute loss # compute loss
loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask]) loss = F.smooth_l1_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
optimizer.zero_grad() optimizer.zero_grad()
(loss * 128).backward() # do not unscale it because we are using Adam.
grad_scaler.scale(loss).backward()
optimizer.step() optimizer.step()
scheduler.step() scheduler.step()
if step % 100 == 0: if step % 100 == 0:
elapsed_time = time.time() - tic elapsed_time = time.time() - tic
loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
print( print(
f"elapsed_time={elapsed_time:.2f}s (data={data_time:.2f}s) | {step=} | " f"elapsed_time={elapsed_time:.2f}s (data={data_time:.2f}s) | {step=} | "
f"loss={loss:.5f} | " f"loss={loss:.5f} | "
f"alive_ray_mask={alive_ray_mask.long().sum():d} | " f"alive_ray_mask={alive_ray_mask.long().sum():d} | "
f"counter={counter:d} | compact_counter={compact_counter:d} | num_rays={len(pixels):d} " f"counter={counter:d} | compact_counter={compact_counter:d} | num_rays={len(pixels):d} |"
) )
# if time.time() - tic > 300: # if time.time() - tic > 300:
if step == 35_000: if step >= 5_000 and step % 5000 == 0 and step > 0:
print("training stops")
# evaluation # evaluation
radiance_field.eval() radiance_field.eval()
psnrs = [] psnrs = []
with torch.no_grad(): with torch.no_grad():
for data in tqdm.tqdm(test_dataloader): for data in tqdm.tqdm(test_dataloader):
...@@ -235,6 +221,41 @@ if __name__ == "__main__": ...@@ -235,6 +221,41 @@ if __name__ == "__main__":
psnrs.append(psnr.item()) psnrs.append(psnr.item())
psnr_avg = sum(psnrs) / len(psnrs) psnr_avg = sum(psnrs) / len(psnrs)
print(f"evaluation: {psnr_avg=}") print(f"evaluation: {psnr_avg=}")
imageio.imwrite(
"acc_binary_test.png",
((acc > 0).float().cpu().numpy() * 255).astype(np.uint8),
)
psnrs = []
train_dataset.training = False
with torch.no_grad():
for data in tqdm.tqdm(train_dataloader):
# generate rays from data and the gt pixel color
rays = namedtuple_map(lambda x: x.to(device), data["rays"])
pixels = data["pixels"].to(device)
render_bkgd = data["color_bkgd"].to(device)
# rendering
rgb, depth, acc, _, _ = render_image(
radiance_field, rays, render_bkgd, render_step_size
)
mse = F.mse_loss(rgb, pixels)
psnr = -10.0 * torch.log(mse) / np.log(10.0)
psnrs.append(psnr.item())
psnr_avg = sum(psnrs) / len(psnrs)
print(f"evaluation on train: {psnr_avg=}")
imageio.imwrite(
"acc_binary_train.png",
((acc > 0).float().cpu().numpy() * 255).astype(np.uint8),
)
print("acc", acc[acc > 0].min())
imageio.imwrite(
"rgb_train.png",
(rgb.cpu().numpy() * 255).astype(np.uint8),
)
train_dataset.training = True
if step == 20_000:
print("training stops")
exit() exit()
tic_data = time.time() tic_data = time.time()
......
...@@ -2,6 +2,7 @@ from typing import Callable, List, Tuple, Union ...@@ -2,6 +2,7 @@ from typing import Callable, List, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from torch_scatter import scatter_max
def meshgrid3d(res: Tuple[int, int, int], device: torch.device = "cpu"): def meshgrid3d(res: Tuple[int, int, int], device: torch.device = "cpu"):
...@@ -144,7 +145,8 @@ class OccupancyField(nn.Module): ...@@ -144,7 +145,8 @@ class OccupancyField(nn.Module):
) / self.resolution_tensor ) / self.resolution_tensor
bb_min, bb_max = torch.split(self.aabb, [self.num_dim, self.num_dim], dim=0) bb_min, bb_max = torch.split(self.aabb, [self.num_dim, self.num_dim], dim=0)
x = x * (bb_max - bb_min) + bb_min x = x * (bb_max - bb_min) + bb_min
tmp_occ_grid[indices] = self.occ_eval_fn(x).squeeze(-1) tmp_occ = self.occ_eval_fn(x).squeeze(-1)
tmp_occ_grid, _ = scatter_max(tmp_occ, indices, dim=0, out=tmp_occ_grid)
# ema update # ema update
ema_mask = (self.occ_grid >= 0) & (tmp_occ_grid >= 0) ema_mask = (self.occ_grid >= 0) & (tmp_occ_grid >= 0)
......
...@@ -95,8 +95,12 @@ def volumetric_rendering( ...@@ -95,8 +95,12 @@ def volumetric_rendering(
(compact_frustum_starts + compact_frustum_ends) / 2.0, (compact_frustum_starts + compact_frustum_ends) / 2.0,
n_rays, n_rays,
) )
# TODO: use transmittance to compose bkgd color:
# https://github.com/NVlabs/instant-ngp/blob/14d6ba6fa899e9f069d2f65d33dbe3cd43056ddd/src/testbed_nerf.cu#L1400
# accumulated_color = linear_to_srgb(accumulated_color)
accumulated_color = accumulated_color + render_bkgd * (1.0 - accumulated_weight) accumulated_color = accumulated_color + render_bkgd * (1.0 - accumulated_weight)
# accumulated_color = srgb_to_linear(accumulated_color)
return ( return (
accumulated_color, accumulated_color,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment