Commit 31422712 authored by Ruilong Li

add examples

parent b33d0d10
# Copyright (c) Meta Platforms, Inc. and affiliates.
import math
import torch
class CachedIterDataset(torch.utils.data.IterableDataset):
def __init__(
self,
training: bool = False,
cache_n_repeat: int = 0,
):
self.training = training
self.cache_n_repeat = cache_n_repeat
self._cache = None
self._n_repeat = 0
def fetch_data(self, index):
"""Fetch the data (it maybe cached for multiple batches)."""
raise NotImplementedError
def preprocess(self, data):
"""Process the fetched / cached data with randomness."""
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
if worker_info is None: # single-process data loading, return the full iterator
iter_start = 0
iter_end = self.__len__()
else: # in a worker process
# split workload
per_worker = int(math.ceil(self.__len__() / float(worker_info.num_workers)))
worker_id = worker_info.id
iter_start = worker_id * per_worker
iter_end = min(iter_start + per_worker, self.__len__())
if self.training:
while True:
for index in iter_start + torch.randperm(iter_end - iter_start):
yield self.__getitem__(index)
else:
for index in range(iter_start, iter_end):
yield self.__getitem__(index)
def __getitem__(self, index):
if (
self.training
and (self._cache is not None)
and (self._n_repeat < self.cache_n_repeat)
):
data = self._cache
self._n_repeat += 1
else:
data = self.fetch_data(index)
self._cache = data
self._n_repeat = 1
return self.preprocess(data)
@classmethod
def collate_fn(cls, batch):
return batch[0]
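

# Illustrative sketch (hypothetical subclass, not part of the API itself): a
# concrete dataset only needs to supply fetch_data / preprocess / __len__.
# With cache_n_repeat > 0 the same fetched item is re-processed with fresh
# randomness for several consecutive batches before a new fetch happens.
class _RandomCropExample(CachedIterDataset):
    def __init__(self, images: torch.Tensor, crop: int = 64):
        super().__init__(training=True, cache_n_repeat=4)
        self.images = images  # [N, H, W, C]
        self.crop = crop

    def __len__(self):
        return len(self.images)

    def fetch_data(self, index):
        # The expensive part (e.g. disk I/O); its result may be cached.
        return self.images[index]

    def preprocess(self, image):
        # The cheap, random part; it runs again on every repeat of the cache.
        h, w, _ = image.shape
        y = int(torch.randint(0, h - self.crop + 1, (1,)))
        x = int(torch.randint(0, w - self.crop + 1, (1,)))
        return image[y : y + self.crop, x : x + self.crop]
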
# Copyright (c) Meta Platforms, Inc. and affiliates.
import json
import os
import cv2
import imageio.v2 as imageio
import numpy as np
import torch
from .base import CachedIterDataset
from .utils import Cameras, generate_rays, transform_cameras
def _load_renderings(root_fp: str, subject_id: str, split: str):
"""Load images from disk."""
if not root_fp.startswith("/"):
# allow relative path. e.g., "./data/nerf_synthetic/"
root_fp = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"..",
"..",
root_fp,
)
data_dir = os.path.join(root_fp, subject_id)
with open(os.path.join(data_dir, "transforms_{}.json".format(split)), "r") as fp:
meta = json.load(fp)
images = []
camtoworlds = []
for i in range(len(meta["frames"])):
frame = meta["frames"][i]
fname = os.path.join(data_dir, frame["file_path"] + ".png")
rgba = imageio.imread(fname)
camtoworlds.append(frame["transform_matrix"])
images.append(rgba)
images = np.stack(images, axis=0).astype(np.float32)
camtoworlds = np.stack(camtoworlds, axis=0).astype(np.float32)
h, w = images.shape[1:3]
camera_angle_x = float(meta["camera_angle_x"])
focal = 0.5 * w / np.tan(0.5 * camera_angle_x)
return images, camtoworlds, focal
class SubjectLoader(CachedIterDataset):
"""Single subject data loader for training and evaluation."""
SPLITS = ["train", "val", "test"]
SUBJECT_IDS = [
"chair",
"drums",
"ficus",
"hotdog",
"lego",
"materials",
"mic",
]
WIDTH, HEIGHT = 800, 800
NEAR, FAR = 2.0, 6.0
def __init__(
self,
subject_id: str,
root_fp: str,
split: str,
resize_factor: float = 1.0,
color_bkgd_aug: str = "white",
num_rays: int = None,
cache_n_repeat: int = 0,
near: float = None,
far: float = None,
):
assert split in self.SPLITS, "%s" % split
assert subject_id in self.SUBJECT_IDS, "%s" % subject_id
assert color_bkgd_aug in ["white", "black", "random"]
self.resize_factor = resize_factor
self.split = split
self.num_rays = num_rays
self.near = self.NEAR if near is None else near
self.far = self.FAR if far is None else far
self.training = (num_rays is not None) and (split in ["train"])
self.color_bkgd_aug = color_bkgd_aug
self.images, self.camtoworlds, self.focal = _load_renderings(
root_fp, subject_id, split
)
assert self.images.shape[1:3] == (self.HEIGHT, self.WIDTH)
super().__init__(self.training, cache_n_repeat)
def __len__(self):
return len(self.images)
# @profile
def preprocess(self, data):
"""Process the fetched / cached data with randomness."""
rgba, rays = data["rgba"], data["rays"]
pixels, alpha = torch.split(rgba, [3, 1], dim=-1)
if self.training:
if self.color_bkgd_aug == "random":
color_bkgd = torch.rand(3)
elif self.color_bkgd_aug == "white":
color_bkgd = torch.ones(3)
elif self.color_bkgd_aug == "black":
color_bkgd = torch.zeros(3)
else:
# just use white during inference
color_bkgd = torch.ones(3)
pixels = pixels * alpha + color_bkgd * (1.0 - alpha)
return {
"pixels": pixels, # [n_rays, 3] or [h, w, 3]
"rays": rays, # [n_rays,] or [h, w]
"color_bkgd": color_bkgd, # [3,]
**{k: v for k, v in data.items() if k not in ["rgba", "rays"]},
}
def fetch_data(self, index):
"""Fetch the data (it maybe cached for multiple batches)."""
# load data
camera_id = index
K = np.array(
[
[self.focal, 0, self.WIDTH / 2.0],
[0, self.focal, self.HEIGHT / 2.0],
[0, 0, 1],
]
).astype(np.float32)
w2c = np.linalg.inv(self.camtoworlds[camera_id])
rgba = self.images[camera_id]
# create pixels
rgba = (
torch.from_numpy(
cv2.resize(
rgba,
(0, 0),
fx=self.resize_factor,
fy=self.resize_factor,
interpolation=cv2.INTER_AREA,
)
).float()
/ 255.0
)
# create rays from camera
cameras = Cameras(
intrins=torch.from_numpy(K).float(),
extrins=torch.from_numpy(w2c).float(),
distorts=None,
width=self.WIDTH,
height=self.HEIGHT,
)
cameras = transform_cameras(cameras, self.resize_factor)
if self.num_rays is not None:
x = torch.randint(0, self.WIDTH, size=(self.num_rays,))
y = torch.randint(0, self.HEIGHT, size=(self.num_rays,))
pixels_xy = torch.stack([x, y], dim=-1)
rgba = rgba[y, x, :]
else:
pixels_xy = None # full image
        # Be careful: this dataset's camera coordinates do not follow the
        # OpenCV convention; they follow the OpenGL convention.
rays = generate_rays(
cameras,
opencv_format=False,
near=self.near,
far=self.far,
pixels_xy=pixels_xy,
)
return {
"camera_id": camera_id,
"rgba": rgba, # [h, w, 4] or [num_rays, 4]
"rays": rays, # [h, w] or [num_rays, 4]
}
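

# Illustrative usage sketch (the data path is a placeholder): construct the
# loader for one subject and pull a single training sample of 1024 random rays.
if __name__ == "__main__":
    dataset = SubjectLoader(
        subject_id="lego",
        root_fp="./data/nerf_synthetic/",  # placeholder; point this at your data
        split="train",
        num_rays=1024,
    )
    sample = dataset[0]
    print(sample["pixels"].shape)        # [1024, 3], alpha-composited colors
    print(sample["rays"].origins.shape)  # [1024, 3]
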
# Copyright (c) Meta Platforms, Inc. and affiliates.
import collections
import math
import torch
import torch.nn.functional as F
Rays = collections.namedtuple(
"Rays", ("origins", "directions", "viewdirs", "radii", "near", "far")
)
Cameras = collections.namedtuple(
"Cameras", ("intrins", "extrins", "distorts", "width", "height")
)
def namedtuple_map(fn, tup):
"""Apply `fn` to each element of `tup` and cast to `tup`'s namedtuple."""
return type(tup)(*(None if x is None else fn(x) for x in tup))
def homo(points: torch.Tensor) -> torch.Tensor:
"""Get the homogeneous coordinates."""
return F.pad(points, (0, 1), value=1)
def transform_cameras(cameras: Cameras, resize_factor: float) -> Cameras:
intrins = cameras.intrins
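    # Note: this scales the intrinsics tensor in place, so the caller's Cameras object is modified as well.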
intrins[..., :2, :] = intrins[..., :2, :] * resize_factor
width = int(cameras.width * resize_factor + 0.5)
height = int(cameras.height * resize_factor + 0.5)
return Cameras(
intrins=intrins,
extrins=cameras.extrins,
distorts=cameras.distorts,
width=width,
height=height,
)
def generate_rays(
cameras: Cameras,
opencv_format: bool = True,
near: float = None,
far: float = None,
pixels_xy: torch.Tensor = None,
) -> Rays:
"""Generating rays for a single or multiple cameras.
:params cameras [(n_cams,)]
:returns: Rays
[(n_cams,) height, width] if pixels_xy is None
[(n_cams,) num_pixels] if pixels_xy is given
"""
if pixels_xy is not None:
K = cameras.intrins[..., None, :, :]
c2w = cameras.extrins[..., None, :, :].inverse()
x, y = pixels_xy[..., 0], pixels_xy[..., 1]
else:
K = cameras.intrins[..., None, None, :, :]
c2w = cameras.extrins[..., None, None, :, :].inverse()
x, y = torch.meshgrid(
torch.arange(cameras.width, dtype=K.dtype),
torch.arange(cameras.height, dtype=K.dtype),
indexing="xy",
) # [height, width]
camera_dirs = homo(
torch.stack(
[
(x - K[..., 0, 2] + 0.5) / K[..., 0, 0],
(y - K[..., 1, 2] + 0.5) / K[..., 1, 1],
],
dim=-1,
)
) # [n_cams, height, width, 3]
if not opencv_format:
camera_dirs[..., [1, 2]] *= -1
# [n_cams, height, width, 3]
directions = (camera_dirs[..., None, :] * c2w[..., :3, :3]).sum(dim=-1)
origins = torch.broadcast_to(c2w[..., :3, -1], directions.shape)
    viewdirs = directions / torch.linalg.norm(directions, dim=-1, keepdim=True)
if pixels_xy is None:
        # Distance between the direction vectors of neighboring pixels.
dx = torch.sqrt(
torch.sum(
(directions[..., :-1, :, :] - directions[..., 1:, :, :]) ** 2,
dim=-1,
)
)
dx = torch.cat([dx, dx[..., -2:-1, :]], dim=-2)
radii = dx[..., None] * 2 / math.sqrt(12) # [n_cams, height, width, 1]
else:
radii = None
if near is not None:
near = near * torch.ones_like(origins[..., 0:1])
if far is not None:
far = far * torch.ones_like(origins[..., 0:1])
rays = Rays(
origins=origins, # [n_cams, height, width, 3]
directions=directions, # [n_cams, height, width, 3]
viewdirs=viewdirs, # [n_cams, height, width, 3]
radii=radii, # [n_cams, height, width, 1]
        # near/far are not needed when they are estimated by the skeleton.
near=near,
far=far,
)
return rays
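

# Illustrative usage sketch (assumed camera values): one pinhole camera at the
# origin, generating a full-image ray grid in the OpenCV convention.
if __name__ == "__main__":
    K = torch.tensor([[100.0, 0.0, 32.0], [0.0, 100.0, 24.0], [0.0, 0.0, 1.0]])
    w2c = torch.eye(4)
    cams = Cameras(intrins=K, extrins=w2c, distorts=None, width=64, height=48)
    rays = generate_rays(cams, opencv_format=True, near=2.0, far=6.0)
    # origins / directions / viewdirs: [48, 64, 3]; radii / near / far: [48, 64, 1]
    print(rays.origins.shape, rays.radii.shape, rays.near.shape)
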
from abc import abstractmethod
from typing import Tuple
import torch
import torch.nn as nn
class BaseRadianceField(nn.Module):
"""An abstract RadianceField class (supports both 2D and 3D).
The key functions to be implemented are:
- forward(positions, directions, masks): returns rgb and density.
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__()
@abstractmethod
def forward(
self,
positions: torch.Tensor,
directions: torch.Tensor = None,
masks: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Returns {rgb, density}."""
raise NotImplementedError()
@abstractmethod
def query_density(self, positions: torch.Tensor) -> torch.Tensor:
"""Returns {density}."""
raise NotImplementedError()
from typing import Callable, List, Union
import torch
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd
try:
import tinycudann as tcnn
except ImportError:
print(
"Please install tinycudann by: "
"pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch"
)
exit()
from .base import BaseRadianceField
class NGPradianceField(BaseRadianceField):
"""Instance-NGP radiance Field"""
class _TruncExp(Function): # pylint: disable=abstract-method
# Implementation from torch-ngp:
# https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, x): # pylint: disable=arguments-differ
ctx.save_for_backward(x)
return torch.exp(x)
@staticmethod
@custom_bwd
def backward(ctx, g): # pylint: disable=arguments-differ
x = ctx.saved_tensors[0]
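            # Clamp before exponentiating so that large saved activations do not overflow to Inf/NaN in the gradient.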
return g * torch.exp(x.clamp(-15, 15))
trunc_exp = _TruncExp.apply
def __init__(
self,
aabb: Union[torch.Tensor, List[float]],
num_dim: int = 3,
use_viewdirs: bool = True,
density_activation: Callable = trunc_exp,
) -> None:
super().__init__()
if not isinstance(aabb, torch.Tensor):
aabb = torch.tensor(aabb, dtype=torch.float32)
self.register_buffer("aabb", aabb)
self.num_dim = num_dim
self.use_viewdirs = use_viewdirs
self.density_activation = density_activation
self.geo_feat_dim = 15
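        # Per-level growth factor of the hash grid: starting from base_resolution 16, the 16 levels reach roughly 16 * 1.4473**15 ≈ 4096 at the finest level.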
per_level_scale = 1.4472692012786865
if self.use_viewdirs:
self.direction_encoding = tcnn.Encoding(
n_input_dims=num_dim,
encoding_config={
"otype": "SphericalHarmonics",
"degree": 4,
},
)
self.mlp_base = tcnn.NetworkWithInputEncoding(
n_input_dims=num_dim,
n_output_dims=1 + self.geo_feat_dim,
encoding_config={
"otype": "HashGrid",
"n_levels": 16,
"n_features_per_level": 2,
"log2_hashmap_size": 19,
"base_resolution": 16,
"per_level_scale": per_level_scale,
},
network_config={
"otype": "FullyFusedMLP",
"activation": "ReLU",
"output_activation": "None",
"n_neurons": 64,
"n_hidden_layers": 1,
},
)
self.mlp_head = tcnn.Network(
n_input_dims=(
(self.direction_encoding.n_output_dims if self.use_viewdirs else 0)
+ self.geo_feat_dim
),
n_output_dims=3,
network_config={
"otype": "FullyFusedMLP",
"activation": "ReLU",
"output_activation": "Sigmoid",
"n_neurons": 64,
"n_hidden_layers": 2,
},
)
@torch.cuda.amp.autocast()
def query_density(self, x, return_feat: bool = False):
bb_min, bb_max = torch.split(self.aabb, [self.num_dim, self.num_dim], dim=0)
x = (x - bb_min) / (bb_max - bb_min)
selector = ((x > 0.0) & (x < 1.0)).all(dim=-1)
x = (
self.mlp_base(x.view(-1, self.num_dim))
.view(list(x.shape[:-1]) + [1 + self.geo_feat_dim])
.to(x)
)
density_before_activation, base_mlp_out = torch.split(
x, [1, self.geo_feat_dim], dim=-1
)
density = (
self.density_activation(density_before_activation) * selector[..., None]
)
if return_feat:
return density, base_mlp_out
else:
return density
@torch.cuda.amp.autocast()
def _query_rgb(self, dir, embedding):
# tcnn requires directions in the range [0, 1]
if self.use_viewdirs:
dir = (dir + 1.0) / 2.0
d = self.direction_encoding(dir.view(-1, dir.shape[-1]))
h = torch.cat([d, embedding.view(-1, self.geo_feat_dim)], dim=-1)
else:
h = embedding.view(-1, self.geo_feat_dim)
rgb = self.mlp_head(h).view(list(embedding.shape[:-1]) + [3]).to(embedding)
return rgb
@torch.cuda.amp.autocast()
def forward(
self,
positions: torch.Tensor,
directions: torch.Tensor = None,
mask: torch.Tensor = None,
):
if self.use_viewdirs and (directions is not None):
assert (
positions.shape == directions.shape
), f"{positions.shape} v.s. {directions.shape}"
if mask is not None:
density = torch.zeros_like(positions[..., :1])
rgb = torch.zeros(list(positions.shape[:-1]) + [3], device=positions.device)
            density[mask], embedding = self.query_density(
                positions[mask], return_feat=True
            )
            rgb[mask] = self._query_rgb(
                directions[mask] if directions is not None else None,
                embedding=embedding,
            )
else:
density, embedding = self.query_density(positions, return_feat=True)
rgb = self._query_rgb(directions, embedding=embedding)
return rgb, density
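

# Illustrative usage sketch (assumes a CUDA device; tinycudann's fully fused
# MLPs only run on GPU): query densities and colors for random points inside
# the scene box.
if __name__ == "__main__":
    aabb = [-1.5, -1.5, -1.5, 1.5, 1.5, 1.5]
    field = NGPradianceField(aabb=aabb).to("cuda")
    positions = torch.rand(1024, 3, device="cuda") * 3.0 - 1.5
    directions = torch.nn.functional.normalize(
        torch.randn(1024, 3, device="cuda"), dim=-1
    )
    rgb, density = field(positions, directions)
    print(rgb.shape, density.shape)  # [1024, 3], [1024, 1]
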
git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
opencv-python
imageio
numpy
tqdm
import math
import time
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from datasets.nerf_synthetic import SubjectLoader
from datasets.utils import namedtuple_map
from radiance_fields.ngp import NGPradianceField
from nerfacc import OccupancyField, volumetric_rendering
def render_image(radiance_field, rays, render_bkgd, chunk=8192):
"""Render the pixels of an image.
Args:
radiance_field: the radiance field of nerf.
        rays: a `Rays` namedtuple, the rays to be rendered.
        render_bkgd: the background color composited behind the rendering.
        chunk: int, the size of chunks to render sequentially.
Returns:
rgb: torch.tensor, rendered color image.
depth: torch.tensor, rendered depth image.
acc: torch.tensor, rendered accumulated weights per pixel.
"""
rays_shape = rays.origins.shape
if len(rays_shape) == 3:
height, width, _ = rays_shape
num_rays = height * width
rays = namedtuple_map(lambda r: r.reshape([num_rays] + list(r.shape[2:])), rays)
else:
num_rays, _ = rays_shape
results = []
for i in range(0, num_rays, chunk):
chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays)
        chunk_color, chunk_depth, chunk_weight, _ = volumetric_rendering(
query_fn=radiance_field.forward, # {x, dir} -> {rgb, density}
rays_o=chunk_rays.origins,
rays_d=chunk_rays.viewdirs,
scene_aabb=occ_field.aabb,
scene_occ_binary=occ_field.occ_grid_binary,
scene_resolution=occ_field.resolution,
render_bkgd=render_bkgd,
render_n_samples=render_n_samples,
)
results.append([chunk_color, chunk_depth, chunk_weight])
rgb, depth, acc = [torch.cat(r, dim=0) for r in zip(*results)]
return (
rgb.view((*rays_shape[:-1], -1)),
depth.view((*rays_shape[:-1], -1)),
acc.view((*rays_shape[:-1], -1)),
)
if __name__ == "__main__":
device = "cuda:0"
# setup dataset
train_dataset = SubjectLoader(
subject_id="lego",
root_fp="/home/ruilongli/data/nerf_synthetic/",
split="train",
num_rays=8192,
)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
num_workers=10,
batch_size=1,
collate_fn=getattr(train_dataset.__class__, "collate_fn"),
)
val_dataset = SubjectLoader(
subject_id="lego",
root_fp="/home/ruilongli/data/nerf_synthetic/",
split="val",
num_rays=None,
)
val_dataloader = torch.utils.data.DataLoader(
val_dataset,
num_workers=1,
batch_size=1,
collate_fn=getattr(train_dataset.__class__, "collate_fn"),
)
# setup the scene bounding box.
scene_aabb = torch.tensor([-1.5, -1.5, -1.5, 1.5, 1.5, 1.5])
# setup the scene radiance field. Assume you have a NeRF model and
# it has following functions:
# - query_density(): {x} -> {density}
# - forward(): {x, dirs} -> {rgb, density}
radiance_field = NGPradianceField(aabb=scene_aabb).to(device)
# setup some rendering settings
render_n_samples = 1024
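    # Choose the step so that a ray spanning the aabb diagonal (max extent * sqrt(3)) is covered by render_n_samples steps.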
render_step_size = (
(scene_aabb[3:] - scene_aabb[:3]).max() * math.sqrt(3) / render_n_samples
)
optimizer = torch.optim.Adam(radiance_field.parameters(), lr=3e-3, eps=1e-15)
# setup occupancy field with eval function
def occ_eval_fn(x: torch.Tensor) -> torch.Tensor:
"""Evaluate occupancy given positions.
Args:
x: positions with shape (N, 3).
Returns:
occupancy values with shape (N, 1).
"""
density_after_activation = radiance_field.query_density(x)
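        # For a small step, opacity 1 - exp(-density * step) ≈ density * step, so this product is used as the occupancy estimate.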
occupancy = density_after_activation * render_step_size
return occupancy
occ_field = OccupancyField(
occ_eval_fn=occ_eval_fn, aabb=scene_aabb, resolution=128
).to(device)
# training
step = 0
tic = time.time()
for epoch in range(100):
for data in train_dataloader:
step += 1
if step > 30_000:
print("training stops")
exit()
# generate rays from data and the gt pixel color
rays = namedtuple_map(lambda x: x.to(device), data["rays"])
pixels = data["pixels"].to(device)
render_bkgd = data["color_bkgd"].to(device)
# update occupancy grid
occ_field.every_n_step(step)
rgb, depth, acc = render_image(radiance_field, rays, render_bkgd)
# compute loss
loss = F.mse_loss(rgb, pixels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if step % 50 == 0:
elapsed_time = time.time() - tic
print(
f"elapsed_time={elapsed_time:.2f}s | {step=} | loss={loss.item(): .5f}"
)
if step % 30_000 == 0 and step > 0:
# evaluation
radiance_field.eval()
psnrs = []
with torch.no_grad():
for data in tqdm.tqdm(val_dataloader):
# generate rays from data and the gt pixel color
rays = namedtuple_map(lambda x: x.to(device), data["rays"])
pixels = data["pixels"].to(device)
render_bkgd = data["color_bkgd"].to(device)
# rendering
rgb, depth, acc = render_image(
radiance_field, rays, render_bkgd
)
mse = F.mse_loss(rgb, pixels)
psnr = -10.0 * torch.log(mse) / np.log(10.0)
psnrs.append(psnr.item())
psnr_avg = sum(psnrs) / len(psnrs)
print(f"evaluation: {psnr_avg=}")
# elapsed_time=312.5340702533722s | step=30000 | loss= 0.00025
# evaluation: psnr_avg=34.261171398162844 (4.12 it/s)