Unverified Commit 8dcfbad9 authored by Ruilong Li (李瑞龙), committed by GitHub

Reformat (#31)



* seems working

* contraction func in cuda

* Update type

* More type updates

* disable DDA for contraction

* update contraction performance in readme

* 360 data: Garden

* eval at max_steps

* add perform of 360 to readme

* fix contraction scaling

* tiny hot fix

* new volrend

* cleanup ray_marching.cu

* cleanup backend

* tests

* cleaning up Grid

* fix doc for grid base class

* check and fix for contraction

* test grid

* rendering and marching

* transmittance_compress verified

* rendering is indeed faster

* pipeline is working

* lego example

* cleanup

* cuda folder is cleaned up! finally!

* cuda formatting

* contraction verify

* upgrade grid

* test for ray marching

* pipeline

* ngp with contraction

* train_ngp runs but slow

* transmittance separated into two. Now NGP is as fast as before

* verified faster than before

* bug fix for contraction

* ngp contraction fix

* tiny cleanup

* contraction works! yay!

* contraction with tanh seems working

* minor update

* support alpha rendering

* absorb visibility to ray marching

* tiny import update

* get rid of contraction temperature

* doc for ContractionType

* doc for Grid

* doc for grid.py is done

* doc for ray marching

* rendering function

* fix doc for rendering

* doc for vol rend

* autosummary for utils

* fix autosummary line break

* utils docs

* api doc is done

* starting work on examples

* contraction for ngp is in python now

* further clean up examples

* mlp nerf is running

* dnerf is in

* update readme command

* merge

* disable pylint error for now

* reformatting and skip tests without cuda

* fix the type issue for contractiontype

* fix cuda attribute issue

* bump to 0.1.0
Co-authored-by: Matt Tancik <tancik@berkeley.edu>
parent a7611603
@@ -27,9 +27,9 @@ jobs:
       run: |
         pip install --upgrade --upgrade-strategy eager -e .[dev]
     - name: Run isort
-      run: isort docs/ nerfacc/ scripts/ tests/ --profile black --check
+      run: isort docs/ nerfacc/ scripts/ examples/ tests/ --profile black --skip examples/pycolmap --check
     - name: Run Black
-      run: black docs/ nerfacc/ scripts/ tests/ --check
+      run: black docs/ nerfacc/ scripts/ examples/ tests/ --exclude examples/pycolmap --check
-    - name: Python Pylint
-      run: |
-        pylint nerfacc tests scripts
+    # - name: Python Pylint
+    #   run: |
+    #     pylint nerfacc/ tests/ scripts/ examples/
[submodule "examples/pycolmap"]
path = examples/pycolmap
url = https://github.com/rmbrualla/pycolmap.git
@@ -7,70 +7,55 @@ This is a **tiny** toolbox for **accelerating** NeRF training & rendering using

 ## Examples: Instant-NGP NeRF

 ``` bash
-python examples/trainval.py ngp --train_split trainval
+python examples/train_ngp_nerf.py --train_split trainval --scene lego
 ```

-Performance on TITAN RTX:
-| trainval | Lego | Mic | Materials | Chair | Hotdog |
-| - | - | - | - | - | - |
-| Time | 300s | 274s | 266s | 341s | 277s |
-| PSNR | 36.61 | 37.62 | 30.11 | 36.09 | 38.09 |
-| FPS | 12.87 | 23.67 | 9.33 | 16.91 | 7.48 |
-
-Instant-NGP paper (5 min) on 3090 (w/ mask):
-| trainval | Lego | Mic | Materials | Chair | Hotdog |
-| - | - | - | - | - | - |
-| PSNR | 36.39 | 36.22 | 29.78 | 35.00 | 37.40 |
+Performance:
+| PSNR | Lego | Mic | Materials | Chair | Hotdog |
+| - | - | - | - | - | - |
+| Paper (5 mins) | 36.39 | 36.22 | 29.78 | 35.00 | 37.40 |
+| Ours (~5 mins) | 36.61 | 37.62 | 30.11 | 36.09 | 38.09 |
+| Exact training time | 300s | 274s | 266s | 341s | 277s |

 ## Examples: Vanilla MLP NeRF

 ``` bash
-python examples/trainval.py vanilla --train_split train
+python examples/train_mlp_nerf.py --train_split train --scene lego
 ```

-Performance on test set:
-| | Lego | Mic | Materials | Chair | Hotdog |
-| - | - | - | - | - | - |
-| Paper PSNR (train set) | 32.54 | 32.91 | 29.62 | 33.00 | 36.18 |
-| Our PSNR (train set) | 33.21 | 33.36 | 29.48 | 32.79 | 35.54 |
-| Our PSNR (trainval set) | 33.66 | - | - | - | - |
-| Our train time & test FPS | 45min; 0.43FPS | 44min; 1FPS | 37min; 0.33FPS* | 44min; 0.57FPS* | 50min; 0.15FPS* |
-
-For reference, the vanilla NeRF paper trains on a V100 GPU for 1-2 days per scene. Test-time rendering takes about 30 secs to render an 800x800 image. Our model is trained on a TITAN X.
-Note: We only use a single MLP with more samples (1024), instead of two MLPs with coarse-to-fine sampling as in the paper. Both ways share the same spirit of dense sampling around the surface. Our fast rendering inherently skips samples away from the surface, so we can simply increase the number of samples with a single MLP to achieve the same goal as coarse-to-fine sampling, without runtime or memory issues.
-*FPS for some scenes are tested under `--test_chunk_size=8192` (default is `81920`) to avoid OOM.
+Performance:
+| PSNR | Lego | Mic | Materials | Chair | Hotdog |
+| - | - | - | - | - | - |
+| Paper (~2 days) | 32.54 | 32.91 | 29.62 | 33.00 | 36.18 |
+| Ours (~45 mins) | 33.21 | 33.36 | 29.48 | 32.79 | 35.54 |

-## Examples: MLP NeRF on Dynamic objects
-Here we trained something similar to D-NeRF on the dnerf dataset:
-``` bash
-python examples/trainval.py dnerf --train_split train --test_chunk_size=8192
+## Examples: MLP NeRF on Dynamic objects (D-NeRF)
+
+```bash
+python examples/train_mlp_dnerf.py --train_split train --scene lego
 ```

-Performance on test set:
+Performance:
 | | Lego | Stand Up |
 | - | - | - |
-| DNeRF paper PSNR (train set) | 21.64 | 32.79 |
-| Our PSNR (train set) | 24.66 | 33.98 |
-| Our train time & test FPS | 43min; 0.15FPS | 41min; 0.4FPS |
-
-Note: the numbers here are tested with version `v0.0.8`.
+| Paper (~2 days) | 21.64 | 32.79 |
+| Ours (~45 mins) | 24.66 | 33.98 |

-<!--
-## Tips:
-1. Sampling rays over all images per iteration (`batch_over_images=True`) is better: `PSNR 33.31 -> 33.75`.
-2. Using a scheduler (`MultiStepLR(optimizer, milestones=[20000, 30000], gamma=0.1)`) to adjust the learning rate gives: `PSNR 33.75 -> 34.40`.
-3. Increasing chunk size (`chunk: 8192 -> 81920`) during inference gives a speedup: `FPS 4.x -> 6.2`.
-4. A random background color (`color_bkgd_aug="random"`) for the `Lego` scene actually hurts: `PSNR 35.42 -> 34.38`.
--->
+## Examples: NGP on unbounded scene
+
+On the MipNeRF360 Garden scene:
+```bash
+python examples/train_ngp_nerf.py --train_split train --scene garden --aabb="-4,-4,-4,4,4,4" --unbounded --cone_angle=0.004
+```
+Performance:
+| | Garden |
+| - | - |
+| Ours | 25.13 |
@@ -3,4 +3,7 @@
   background-size: 150px 40px;
   height: 40px;
   width: 150px;
 }
+code {
+  word-break: normal;
+}
nerfacc.accumulate\_along\_rays
===============================
.. currentmodule:: nerfacc
.. autofunction:: accumulate_along_rays
nerfacc.ray\_aabb\_intersect
============================
.. currentmodule:: nerfacc
.. autofunction:: ray_aabb_intersect
nerfacc.render\_visibility
==========================
.. currentmodule:: nerfacc
.. autofunction:: render_visibility
OccupancyField
===================================
.. currentmodule:: nerfacc
.. autoclass:: OccupancyField
   :members:
   :show-inheritance:

nerfacc.render\_weight\_from\_alpha
===================================
.. currentmodule:: nerfacc
.. autofunction:: render_weight_from_alpha
nerfacc.render\_weight\_from\_density
=====================================
.. currentmodule:: nerfacc
.. autofunction:: render_weight_from_density
nerfacc.unpack\_to\_ray\_indices
================================
.. currentmodule:: nerfacc
.. autofunction:: unpack_to_ray_indices
Occupancy Grid
===================================
.. currentmodule:: nerfacc
.. autoclass:: ContractionType
:members:
.. autoclass:: Grid
:members:
.. autoclass:: OccupancyGrid
:members:
Volumetric Rendering
===================================
In `nerfacc`, the volumetric rendering pipeline is broken down into two steps:
1. **Raymarching**: This is the process of shooting a ray through the scene and
generating samples along the way. To make volumetric rendering efficient, we aim
to skip as many areas as possible: empty space is skipped by using the cached
occupancy grid (see :class:`nerfacc.OccupancyGrid`), and invisible space is skipped by
checking the transmittance of the ray while marching. In almost all cases this skipping
causes no noticeable loss of quality, because the skipped samples would contribute very
little to the final rendered image, while it brings a significant speedup.
2. **Rendering**: This is the process of accumulating the samples along each ray into the
final image. In this step we also query the attributes (i.e., color and density) of the
samples generated by raymarching. Early stopping is supported in this step.
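A minimal, illustrative sketch of how the two steps fit together. The argument
names (``rays_o``, ``rays_d``, ``occupancy_grid``, ``sigma_fn``, ``rgb_sigma_fn``)
and the exact call layout are assumptions made for illustration; see the function
references below for the authoritative signatures.

.. code-block:: python

    import nerfacc

    # Assumed inputs: rays_o / rays_d are (n_rays, 3) tensors, occupancy_grid is
    # a cached nerfacc.OccupancyGrid, and sigma_fn / rgb_sigma_fn query the
    # radiance field for density (and color) at the proposed samples.

    # Step 1, raymarching: generate packed samples along the rays while skipping
    # empty space (occupancy grid) and invisible space (transmittance check).
    packed_info, t_starts, t_ends = nerfacc.ray_marching(
        rays_o, rays_d, grid=occupancy_grid, sigma_fn=sigma_fn
    )

    # Step 2, rendering: query color/density at those samples and accumulate
    # them along each ray into colors, opacities and depths.
    colors, opacities, depths = nerfacc.rendering(
        rgb_sigma_fn, packed_info, t_starts, t_ends
    )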
|
.. currentmodule:: nerfacc
.. autofunction:: ray_marching
.. autofunction:: rendering
Utils
===================================
.. currentmodule:: nerfacc
.. autosummary::
:nosignatures:
:toctree: generated/
ray_aabb_intersect
unpack_to_ray_indices
accumulate_along_rays
render_weight_from_density
render_weight_from_alpha
render_visibility
Volumetric Ray Marching
=========================
.. currentmodule:: nerfacc
.. autofunction:: ray_aabb_intersect
.. autofunction:: volumetric_marching
Volumetric Rendering
======================
.. currentmodule:: nerfacc
.. autofunction:: volumetric_rendering_pipeline
.. autofunction:: volumetric_rendering_steps
.. autofunction:: volumetric_rendering_weights
.. autofunction:: volumetric_rendering_accumulate
.. autofunction:: unpack_to_ray_indices
# Examples using nerfacc
## Installation
Extra dependencies are needed.
Make sure you are using a version of PyTorch that supports CUDA 11.
```
pip install torch --extra-index-url https://download.pytorch.org/whl/cu113
```
Then install via pip
```
pip install nerfacc[examples]
```
To install locally
```
pip install -e .[examples]
```
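Before running the examples, it is worth confirming that the installed PyTorch build actually sees a CUDA device, since the CUDA extensions require one. A minimal sanity-check sketch:
```python
import torch

# nerfacc's CUDA kernels and the examples require a CUDA-enabled build.
assert torch.cuda.is_available(), "A CUDA-enabled PyTorch build is required."
print(torch.version.cuda)  # e.g. "11.3" for the cu113 wheel above
```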
# Copyright (c) Meta Platforms, Inc. and affiliates.
import math
import torch
class CachedIterDataset(torch.utils.data.IterableDataset):
def __init__(
self,
training: bool = False,
cache_n_repeat: int = 0,
):
self.training = training
self.cache_n_repeat = cache_n_repeat
self._cache = None
self._n_repeat = 0
def fetch_data(self, index):
"""Fetch the data (it maybe cached for multiple batches)."""
raise NotImplementedError
def preprocess(self, data):
"""Process the fetched / cached data with randomness."""
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
if worker_info is None: # single-process data loading, return the full iterator
iter_start = 0
iter_end = self.__len__()
else: # in a worker process
# split workload
per_worker = int(math.ceil(self.__len__() / float(worker_info.num_workers)))
worker_id = worker_info.id
iter_start = worker_id * per_worker
iter_end = min(iter_start + per_worker, self.__len__())
if self.training:
for index in iter_start + torch.randperm(iter_end - iter_start):
yield self.__getitem__(index)
else:
for index in range(iter_start, iter_end):
yield self.__getitem__(index)
def __getitem__(self, index):
if (
self.training
and (self._cache is not None)
and (self._n_repeat < self.cache_n_repeat)
):
data = self._cache
self._n_repeat += 1
else:
data = self.fetch_data(index)
self._cache = data
self._n_repeat = 1
return self.preprocess(data)
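Below is a hypothetical subclass sketch (the `RandomPixelsDataset` name and its fields are illustrative, not part of the repo) showing how the hooks above fit together: `fetch_data` does the expensive loading, `preprocess` adds per-iteration randomness, and `cache_n_repeat` controls how many consecutive batches reuse the cached fetch.
```python
import torch


class RandomPixelsDataset(CachedIterDataset):
    def __init__(self, images: torch.Tensor, num_rays: int = 1024, **kwargs):
        super().__init__(**kwargs)
        self.images = images  # [n_images, height, width, 3]
        self.num_rays = num_rays

    def __len__(self):
        return len(self.images)

    def fetch_data(self, index):
        # Expensive part: load one image; reused for `cache_n_repeat` batches.
        return {"image": self.images[index]}

    def preprocess(self, data):
        # Cheap, randomized part: sample a fresh set of pixels every batch.
        h, w, _ = data["image"].shape
        ys = torch.randint(0, h, (self.num_rays,))
        xs = torch.randint(0, w, (self.num_rays,))
        return {"pixels": data["image"][ys, xs]}


dataset = RandomPixelsDataset(
    torch.rand(8, 64, 64, 3), training=True, cache_n_repeat=4
)
loader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=0)
for i, batch in enumerate(loader):
    print(i, batch["pixels"].shape)  # torch.Size([1024, 3])
```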
-import collections
 import json
 import os
@@ -7,12 +6,7 @@ import numpy as np
 import torch
 import torch.nn.functional as F

-Rays = collections.namedtuple("Rays", ("origins", "viewdirs"))
-
-
-def namedtuple_map(fn, tup):
-    """Apply `fn` to each element of `tup` and cast to `tup`'s namedtuple."""
-    return type(tup)(*(None if x is None else fn(x) for x in tup))
+from .utils import Rays


 def _load_renderings(root_fp: str, subject_id: str, split: str):
@@ -27,7 +21,9 @@ def _load_renderings(root_fp: str, subject_id: str, split: str):
     )

     data_dir = os.path.join(root_fp, subject_id)
-    with open(os.path.join(data_dir, "transforms_{}.json".format(split)), "r") as fp:
+    with open(
+        os.path.join(data_dir, "transforms_{}.json".format(split)), "r"
+    ) as fp:
         meta = json.load(fp)
     images = []
     camtoworlds = []
@@ -38,7 +34,9 @@ def _load_renderings(root_fp: str, subject_id: str, split: str):
         fname = os.path.join(data_dir, frame["file_path"] + ".png")
         rgba = imageio.imread(fname)
         timestamp = (
-            frame["time"] if "time" in frame else float(i) / (len(meta["frames"]) - 1)
+            frame["time"]
+            if "time" in frame
+            else float(i) / (len(meta["frames"]) - 1)
         )
         timestamps.append(timestamp)
         camtoworlds.append(frame["transform_matrix"])
@@ -93,15 +91,22 @@ class SubjectLoader(torch.utils.data.Dataset):
         self.num_rays = num_rays
         self.near = self.NEAR if near is None else near
         self.far = self.FAR if far is None else far
-        self.training = (num_rays is not None) and (split in ["train", "trainval"])
+        self.training = (num_rays is not None) and (
+            split in ["train", "trainval"]
+        )
         self.color_bkgd_aug = color_bkgd_aug
         self.batch_over_images = batch_over_images
-        self.images, self.camtoworlds, self.focal, self.timestamps = _load_renderings(
-            root_fp, subject_id, split
-        )
+        (
+            self.images,
+            self.camtoworlds,
+            self.focal,
+            self.timestamps,
+        ) = _load_renderings(root_fp, subject_id, split)
         self.images = torch.from_numpy(self.images).to(torch.uint8)
         self.camtoworlds = torch.from_numpy(self.camtoworlds).to(torch.float32)
-        self.timestamps = torch.from_numpy(self.timestamps).to(torch.float32)[:, None]
+        self.timestamps = torch.from_numpy(self.timestamps).to(torch.float32)[
+            :, None
+        ]
         self.K = torch.tensor(
             [
                 [self.focal, 0, self.WIDTH / 2.0],
@@ -198,7 +203,9 @@ class SubjectLoader(torch.utils.data.Dataset):
         # [n_cams, height, width, 3]
         directions = (camera_dirs[:, None, :] * c2w[:, :3, :3]).sum(dim=-1)
         origins = torch.broadcast_to(c2w[:, :3, -1], directions.shape)
-        viewdirs = directions / torch.linalg.norm(directions, dim=-1, keepdims=True)
+        viewdirs = directions / torch.linalg.norm(
+            directions, dim=-1, keepdims=True
+        )

         if self.training:
             origins = torch.reshape(origins, (num_rays, 3))
import collections
import os
import sys
import imageio
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from .utils import Rays
_PATH = os.path.abspath(__file__)
sys.path.insert(
0, os.path.join(os.path.dirname(_PATH), "..", "pycolmap", "pycolmap")
)
from scene_manager import SceneManager
def _load_colmap(root_fp: str, subject_id: str, split: str, factor: int = 1):
assert factor in [1, 2, 4, 8]
data_dir = os.path.join(root_fp, subject_id)
colmap_dir = os.path.join(data_dir, "sparse/0/")
manager = SceneManager(colmap_dir)
manager.load_cameras()
manager.load_images()
# Assume shared intrinsics between all cameras.
cam = manager.cameras[1]
fx, fy, cx, cy = cam.fx, cam.fy, cam.cx, cam.cy
K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
K[:2, :] /= factor
# Extract extrinsic matrices in world-to-camera format.
imdata = manager.images
w2c_mats = []
bottom = np.array([0, 0, 0, 1]).reshape(1, 4)
for k in imdata:
im = imdata[k]
rot = im.R()
trans = im.tvec.reshape(3, 1)
w2c = np.concatenate([np.concatenate([rot, trans], 1), bottom], axis=0)
w2c_mats.append(w2c)
w2c_mats = np.stack(w2c_mats, axis=0)
# Convert extrinsics to camera-to-world.
camtoworlds = np.linalg.inv(w2c_mats)
# Image names from COLMAP. No need for permuting the poses according to
# image names anymore.
image_names = [imdata[k].name for k in imdata]
# # Switch from COLMAP (right, down, fwd) to NeRF (right, up, back) frame.
# poses = poses @ np.diag([1, -1, -1, 1])
# Get distortion parameters.
type_ = cam.camera_type
if type_ == 0 or type_ == "SIMPLE_PINHOLE":
params = None
camtype = "perspective"
elif type_ == 1 or type_ == "PINHOLE":
params = None
camtype = "perspective"
if type_ == 2 or type_ == "SIMPLE_RADIAL":
params = {k: 0.0 for k in ["k1", "k2", "k3", "p1", "p2"]}
params["k1"] = cam.k1
camtype = "perspective"
elif type_ == 3 or type_ == "RADIAL":
params = {k: 0.0 for k in ["k1", "k2", "k3", "p1", "p2"]}
params["k1"] = cam.k1
params["k2"] = cam.k2
camtype = "perspective"
elif type_ == 4 or type_ == "OPENCV":
params = {k: 0.0 for k in ["k1", "k2", "k3", "p1", "p2"]}
params["k1"] = cam.k1
params["k2"] = cam.k2
params["p1"] = cam.p1
params["p2"] = cam.p2
camtype = "perspective"
elif type_ == 5 or type_ == "OPENCV_FISHEYE":
params = {k: 0.0 for k in ["k1", "k2", "k3", "k4"]}
params["k1"] = cam.k1
params["k2"] = cam.k2
params["k3"] = cam.k3
params["k4"] = cam.k4
camtype = "fisheye"
assert params is None, "Only support pinhole camera model."
# Previous NeRF results were generated with images sorted by filename,
# ensure metrics are reported on the same test set.
inds = np.argsort(image_names)
image_names = [image_names[i] for i in inds]
camtoworlds = camtoworlds[inds]
# Load images.
if factor > 1:
image_dir_suffix = f"_{factor}"
else:
image_dir_suffix = ""
colmap_image_dir = os.path.join(data_dir, "images")
image_dir = os.path.join(data_dir, "images" + image_dir_suffix)
for d in [image_dir, colmap_image_dir]:
if not os.path.exists(d):
raise ValueError(f"Image folder {d} does not exist.")
# Downsampled images may have different names vs images used for COLMAP,
# so we need to map between the two sorted lists of files.
colmap_files = sorted(os.listdir(colmap_image_dir))
image_files = sorted(os.listdir(image_dir))
colmap_to_image = dict(zip(colmap_files, image_files))
image_paths = [
os.path.join(image_dir, colmap_to_image[f]) for f in image_names
]
print("loading images")
images = [imageio.imread(x) for x in tqdm.tqdm(image_paths)]
images = np.stack(images, axis=0)
# Select the split.
all_indices = np.arange(images.shape[0])
split_indices = {
"test": all_indices[all_indices % 8 == 0],
"train": all_indices[all_indices % 8 != 0],
}
indices = split_indices[split]
# All per-image quantities must be re-indexed using the split indices.
images = images[indices]
camtoworlds = camtoworlds[indices]
return images, camtoworlds, K
class SubjectLoader(torch.utils.data.Dataset):
"""Single subject data loader for training and evaluation."""
SPLITS = ["train", "test"]
SUBJECT_IDS = [
"garden",
]
OPENGL_CAMERA = False
def __init__(
self,
subject_id: str,
root_fp: str,
split: str,
color_bkgd_aug: str = "white",
num_rays: int = None,
near: float = None,
far: float = None,
batch_over_images: bool = True,
factor: int = 1,
):
super().__init__()
assert split in self.SPLITS, "%s" % split
assert subject_id in self.SUBJECT_IDS, "%s" % subject_id
assert color_bkgd_aug in ["white", "black", "random"]
self.split = split
self.num_rays = num_rays
self.near = near
self.far = far
self.training = (num_rays is not None) and (
split in ["train", "trainval"]
)
self.color_bkgd_aug = color_bkgd_aug
self.batch_over_images = batch_over_images
self.images, self.camtoworlds, self.K = _load_colmap(
root_fp, subject_id, split, factor
)
self.images = torch.from_numpy(self.images).to(torch.uint8)
self.camtoworlds = torch.from_numpy(self.camtoworlds).to(torch.float32)
self.K = torch.tensor(self.K).to(torch.float32)
self.height, self.width = self.images.shape[1:3]
def __len__(self):
return len(self.images)
@torch.no_grad()
def __getitem__(self, index):
data = self.fetch_data(index)
data = self.preprocess(data)
return data
def preprocess(self, data):
"""Process the fetched / cached data with randomness."""
pixels, rays = data["rgb"], data["rays"]
if self.training:
if self.color_bkgd_aug == "random":
color_bkgd = torch.rand(3, device=self.images.device)
elif self.color_bkgd_aug == "white":
color_bkgd = torch.ones(3, device=self.images.device)
elif self.color_bkgd_aug == "black":
color_bkgd = torch.zeros(3, device=self.images.device)
else:
# just use white during inference
color_bkgd = torch.ones(3, device=self.images.device)
return {
"pixels": pixels, # [n_rays, 3] or [h, w, 3]
"rays": rays, # [n_rays,] or [h, w]
"color_bkgd": color_bkgd, # [3,]
**{k: v for k, v in data.items() if k not in ["rgb", "rays"]},
}
def update_num_rays(self, num_rays):
self.num_rays = num_rays
def fetch_data(self, index):
"""Fetch the data (it maybe cached for multiple batches)."""
num_rays = self.num_rays
if self.training:
if self.batch_over_images:
image_id = torch.randint(
0,
len(self.images),
size=(num_rays,),
device=self.images.device,
)
else:
image_id = [index]
x = torch.randint(
0, self.width, size=(num_rays,), device=self.images.device
)
y = torch.randint(
0, self.height, size=(num_rays,), device=self.images.device
)
else:
image_id = [index]
x, y = torch.meshgrid(
torch.arange(self.width, device=self.images.device),
torch.arange(self.height, device=self.images.device),
indexing="xy",
)
x = x.flatten()
y = y.flatten()
# generate rays
rgb = self.images[image_id, y, x] / 255.0 # (num_rays, 3)
c2w = self.camtoworlds[image_id] # (num_rays, 3, 4)
camera_dirs = F.pad(
torch.stack(
[
(x - self.K[0, 2] + 0.5) / self.K[0, 0],
(y - self.K[1, 2] + 0.5)
/ self.K[1, 1]
* (-1.0 if self.OPENGL_CAMERA else 1.0),
],
dim=-1,
),
(0, 1),
value=(-1.0 if self.OPENGL_CAMERA else 1.0),
) # [num_rays, 3]
# [n_cams, height, width, 3]
directions = (camera_dirs[:, None, :] * c2w[:, :3, :3]).sum(dim=-1)
origins = torch.broadcast_to(c2w[:, :3, -1], directions.shape)
viewdirs = directions / torch.linalg.norm(
directions, dim=-1, keepdims=True
)
if self.training:
origins = torch.reshape(origins, (num_rays, 3))
viewdirs = torch.reshape(viewdirs, (num_rays, 3))
rgb = torch.reshape(rgb, (num_rays, 3))
else:
origins = torch.reshape(origins, (self.height, self.width, 3))
viewdirs = torch.reshape(viewdirs, (self.height, self.width, 3))
rgb = torch.reshape(rgb, (self.height, self.width, 3))
rays = Rays(origins=origins, viewdirs=viewdirs)
return {
"rgb": rgb, # [h, w, 3] or [num_rays, 3]
"rays": rays, # [h, w, 3] or [num_rays, 3]
}
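A hypothetical usage sketch (the dataset path is a placeholder) that loads the Garden test split at 1/4 resolution and inspects the per-view tensors produced by `fetch_data` / `preprocess`:
```python
dataset = SubjectLoader(
    subject_id="garden",
    root_fp="/path/to/360_v2",  # placeholder: folder containing the Garden scene
    split="test",
    factor=4,
)
sample = dataset[0]
print(sample["pixels"].shape)        # [height, width, 3], RGB in [0, 1]
print(sample["rays"].origins.shape)  # [height, width, 3]
print(sample["color_bkgd"])          # white background used at eval time
```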
@@ -7,12 +7,7 @@ import numpy as np
 import torch
 import torch.nn.functional as F

-Rays = collections.namedtuple("Rays", ("origins", "viewdirs"))
-
-
-def namedtuple_map(fn, tup):
-    """Apply `fn` to each element of `tup` and cast to `tup`'s namedtuple."""
-    return type(tup)(*(None if x is None else fn(x) for x in tup))
+from .utils import Rays


 def _load_renderings(root_fp: str, subject_id: str, split: str):
@@ -27,7 +22,9 @@ def _load_renderings(root_fp: str, subject_id: str, split: str):
     )

     data_dir = os.path.join(root_fp, subject_id)
-    with open(os.path.join(data_dir, "transforms_{}.json".format(split)), "r") as fp:
+    with open(
+        os.path.join(data_dir, "transforms_{}.json".format(split)), "r"
+    ) as fp:
         meta = json.load(fp)
     images = []
     camtoworlds = []
@@ -87,7 +84,9 @@ class SubjectLoader(torch.utils.data.Dataset):
         self.num_rays = num_rays
         self.near = self.NEAR if near is None else near
         self.far = self.FAR if far is None else far
-        self.training = (num_rays is not None) and (split in ["train", "trainval"])
+        self.training = (num_rays is not None) and (
+            split in ["train", "trainval"]
+        )
         self.color_bkgd_aug = color_bkgd_aug
         self.batch_over_images = batch_over_images
         if split == "trainval":
@@ -98,7 +97,9 @@ class SubjectLoader(torch.utils.data.Dataset):
                 root_fp, subject_id, "val"
             )
             self.images = np.concatenate([_images_train, _images_val])
-            self.camtoworlds = np.concatenate([_camtoworlds_train, _camtoworlds_val])
+            self.camtoworlds = np.concatenate(
+                [_camtoworlds_train, _camtoworlds_val]
+            )
             self.focal = _focal_train
         else:
             self.images, self.camtoworlds, self.focal = _load_renderings(
@@ -202,7 +203,9 @@ class SubjectLoader(torch.utils.data.Dataset):
         # [n_cams, height, width, 3]
         directions = (camera_dirs[:, None, :] * c2w[:, :3, :3]).sum(dim=-1)
         origins = torch.broadcast_to(c2w[:, :3, -1], directions.shape)
-        viewdirs = directions / torch.linalg.norm(directions, dim=-1, keepdims=True)
+        viewdirs = directions / torch.linalg.norm(
+            directions, dim=-1, keepdims=True
+        )

         if self.training:
             origins = torch.reshape(origins, (num_rays, 3))