Commit 16324602 authored by Ruilong Li

benchmark

parent 96211bba
@@ -10,6 +10,12 @@ python examples/trainval.py
 ## Performance Reference
 
+| trainval (35k steps, 1<<16 samples/batch) | Lego | Mic | Materials |
+| - | - | - | - |
+| Time | 377s | 357s | 354s |
+| PSNR | 36.08 | 36.58 | 29.63 |
+
 Tested with the default settings on the Lego test set.
 
 | Model | Split | PSNR | Train Time | Test Speed | GPU | Train Memory |
...
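A note on the PSNR rows above and the evaluation code restored in `trainval.py` below: PSNR is derived from the test-set mean squared error as -10·log10(MSE). A minimal sketch, assuming predictions and ground truth are float tensors in [0, 1]:

```python
import torch

def psnr(pred: torch.Tensor, gt: torch.Tensor) -> float:
    # PSNR = -10 * log10(MSE) for images normalized to [0, 1]
    mse = torch.mean((pred - gt) ** 2)
    return (-10.0 * torch.log10(mse)).item()
```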
@@ -187,9 +187,7 @@ class SubjectLoader(torch.utils.data.Dataset):
         camera_dirs = F.pad(
             torch.stack(
                 [
-                    (x - self.K[0, 2] + 0.5)
-                    / self.K[0, 0]
-                    * (-1.0 if self.OPENGL_CAMERA else 1.0),
+                    (x - self.K[0, 2] + 0.5) / self.K[0, 0],
                     (y - self.K[1, 2] + 0.5)
                     / self.K[1, 1]
                     * (-1.0 if self.OPENGL_CAMERA else 1.0),
@@ -197,7 +195,7 @@ class SubjectLoader(torch.utils.data.Dataset):
                 dim=-1,
             ),
             (0, 1),
-            value=1,
+            value=(-1.0 if self.OPENGL_CAMERA else 1.0),
         )  # [num_rays, 3]
         # [n_cams, height, width, 3]
...
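Context for the loader fix above: an OpenGL-style camera has +x right, +y up, +z pointing backward, while pixel coordinates grow rightward and downward, so only the y and z components of the ray direction need the sign flip; the x term keeps its sign, and the z sign now goes into the pad value. A standalone sketch of the same construction, where `K`, `width`, `height`, and the `opengl` flag stand in for the loader's attributes:

```python
import torch
import torch.nn.functional as F

def pixel_ray_dirs(K: torch.Tensor, width: int, height: int, opengl: bool = True):
    """Per-pixel ray directions in camera space for a pinhole camera K."""
    sign = -1.0 if opengl else 1.0  # OpenGL flips y (up) and z (backward)
    y, x = torch.meshgrid(
        torch.arange(height), torch.arange(width), indexing="ij"
    )
    dirs = torch.stack(
        [
            (x - K[0, 2] + 0.5) / K[0, 0],         # x keeps its sign
            (y - K[1, 2] + 0.5) / K[1, 1] * sign,  # y flips for OpenGL
        ],
        dim=-1,
    )
    # pad in the z component with the same sign convention
    return F.pad(dirs, (0, 1), value=sign)  # [height, width, 3]
```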
@@ -98,7 +98,6 @@ class NGPradianceField(BaseRadianceField):
         },
     )
 
-    @torch.cuda.amp.autocast()
     def query_density(self, x, return_feat: bool = False):
         bb_min, bb_max = torch.split(self.aabb, [self.num_dim, self.num_dim], dim=0)
         x = (x - bb_min) / (bb_max - bb_min)
@@ -119,7 +118,6 @@ class NGPradianceField(BaseRadianceField):
         else:
             return density
 
-    @torch.cuda.amp.autocast()
     def _query_rgb(self, dir, embedding):
         # tcnn requires directions in the range [0, 1]
         if self.use_viewdirs:
@@ -131,7 +129,6 @@ class NGPradianceField(BaseRadianceField):
         rgb = self.mlp_head(h).view(list(embedding.shape[:-1]) + [3]).to(embedding)
         return rgb
 
-    @torch.cuda.amp.autocast()
     def forward(
         self,
         positions: torch.Tensor,
...
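With the `@torch.cuda.amp.autocast()` decorators removed, the field no longer forces mixed precision on every caller. If autocast is still wanted for a particular call, one option is to opt in at the call site; a sketch under that assumption (the commit itself only drops the decorators, it does not add this anywhere):

```python
import torch

def query_density_amp(radiance_field, x: torch.Tensor) -> torch.Tensor:
    # Opt in to mixed precision here instead of hard-coding it
    # inside the radiance field's methods.
    with torch.cuda.amp.autocast():
        return radiance_field.query_density(x)
```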
@@ -5,13 +5,15 @@ import numpy as np
 import torch
 import torch.nn.functional as F
 import tqdm
-from datasets.nerf_synthetic import Rays, SubjectLoader, namedtuple_map
+from datasets.nerf_synthetic import SubjectLoader, namedtuple_map
 from radiance_fields.ngp import NGPradianceField
 
 from nerfacc import OccupancyField, volumetric_rendering
 
+TARGET_SAMPLE_BATCH_SIZE = 1 << 16
+
 
-def render_image(radiance_field, rays, render_bkgd):
+def render_image(radiance_field, rays, render_bkgd, render_step_size):
     """Render the pixels of an image.
 
     Args:
@@ -32,7 +34,9 @@ def render_image(radiance_field, rays, render_bkgd):
     num_rays, _ = rays_shape
 
     results = []
     chunk = torch.iinfo(torch.int32).max if radiance_field.training else 81920
-    render_est_n_samples = 2**16 * 16 if radiance_field.training else None
+    render_est_n_samples = (
+        TARGET_SAMPLE_BATCH_SIZE * 16 if radiance_field.training else None
+    )
     for i in range(0, num_rays, chunk):
         chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays)
         chunk_results = volumetric_rendering(
@@ -45,6 +49,7 @@ def render_image(radiance_field, rays, render_bkgd):
             render_bkgd=render_bkgd,
             render_n_samples=render_n_samples,
             render_est_n_samples=render_est_n_samples,  # memory control: worst case
+            render_step_size=render_step_size,
         )
         results.append(chunk_results)
     rgb, depth, acc, alive_ray_mask, counter, compact_counter = [
@@ -64,13 +69,14 @@ if __name__ == "__main__":
     torch.manual_seed(42)
 
     device = "cuda:0"
+    scene = "lego"
 
     # setup dataset
     train_dataset = SubjectLoader(
-        subject_id="mic",
+        subject_id=scene,
         root_fp="/home/ruilongli/data/nerf_synthetic/",
         split="trainval",
-        num_rays=409600,
+        num_rays=4096,
     )
 
     train_dataset.images = train_dataset.images.to(device)
@@ -85,7 +91,7 @@ if __name__ == "__main__":
     )
 
     test_dataset = SubjectLoader(
-        subject_id="mic",
+        subject_id=scene,
         root_fp="/home/ruilongli/data/nerf_synthetic/",
         split="test",
         num_rays=None,
@@ -112,7 +118,7 @@ if __name__ == "__main__":
     render_n_samples = 1024
     render_step_size = (
         (scene_aabb[3:] - scene_aabb[:3]).max() * math.sqrt(3) / render_n_samples
-    )
+    ).item()
 
     optimizer = torch.optim.Adam(
         radiance_field.parameters(),
@@ -144,123 +150,75 @@ if __name__ == "__main__":
         occ_eval_fn=occ_eval_fn, aabb=scene_aabb, resolution=128
     ).to(device)
 
-    render_bkgd = torch.ones(3, device=device)
-
     # training
     step = 0
     tic = time.time()
     data_time = 0
     tic_data = time.time()
-    weights_image_ids = torch.ones((len(train_dataset.images),), device=device)
-    weights_xs = torch.ones(
-        (train_dataset.WIDTH,),
-        device=device,
-    )
-    weights_ys = torch.ones(
-        (train_dataset.HEIGHT,),
-        device=device,
-    )
-    for epoch in range(40000000):
-        data = train_dataset[0]
+    for epoch in range(10000000):
         for i in range(len(train_dataset)):
             data = train_dataset[i]
             data_time += time.time() - tic_data
-            if step > 35_000:
-                print("training stops")
-                exit()
             # generate rays from data and the gt pixel color
-            rays = namedtuple_map(lambda x: x.to(device), data["rays"])
-            pixels = data["pixels"].to(device)
-            render_bkgd = data["color_bkgd"].to(device)
+            # rays = namedtuple_map(lambda x: x.to(device), data["rays"])
+            # pixels = data["pixels"].to(device)
+            render_bkgd = data["color_bkgd"]
+            rays = data["rays"]
+            pixels = data["pixels"]
 
-            # # update occupancy grid
-            # occ_field.every_n_step(step)
+            # update occupancy grid
+            occ_field.every_n_step(step)
 
-            render_est_n_samples = 2**16 * 16 if radiance_field.training else None
-            volumetric_rendering(
-                query_fn=radiance_field.forward,  # {x, dir} -> {rgb, density}
-                rays_o=rays.origins,
-                rays_d=rays.viewdirs,
-                scene_aabb=occ_field.aabb,
-                scene_occ_binary=occ_field.occ_grid_binary,
-                scene_resolution=occ_field.resolution,
-                render_bkgd=render_bkgd,
-                render_n_samples=render_n_samples,
-                render_est_n_samples=render_est_n_samples,  # memory control: worst case
-            )
-
-            # rgb, depth, acc, alive_ray_mask, counter, compact_counter = render_image(
-            #     radiance_field, rays, render_bkgd
-            # )
-            # num_rays = len(pixels)
-            # num_rays = int(num_rays * (2**16 / float(compact_counter)))
-            # num_rays = int(math.ceil(num_rays / 128.0) * 128)
-            # train_dataset.update_num_rays(num_rays)
+            rgb, depth, acc, alive_ray_mask, counter, compact_counter = render_image(
+                radiance_field, rays, render_bkgd, render_step_size
+            )
+            num_rays = len(pixels)
+            num_rays = int(
+                num_rays * (TARGET_SAMPLE_BATCH_SIZE / float(compact_counter.item()))
+            )
+            train_dataset.update_num_rays(num_rays)
 
-            # # compute loss
-            # loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
-            # optimizer.zero_grad()
-            # (loss * 128.0).backward()
-            # optimizer.step()
-            # scheduler.step()
+            # compute loss
+            loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
+            optimizer.zero_grad()
+            (loss * 128).backward()
+            optimizer.step()
+            scheduler.step()
 
-            if step % 50 == 0:
+            if step % 100 == 0:
                 elapsed_time = time.time() - tic
                 print(
                     f"elapsed_time={elapsed_time:.2f}s (data={data_time:.2f}s) | {step=} | "
-                    # f"loss={loss:.5f} | "
-                    # f"alive_ray_mask={alive_ray_mask.long().sum():d} | "
-                    # f"counter={counter:d} | compact_counter={compact_counter:d} | num_rays={len(pixels):d} "
+                    f"loss={loss:.5f} | "
+                    f"alive_ray_mask={alive_ray_mask.long().sum():d} | "
+                    f"counter={counter.item():d} | compact_counter={compact_counter.item():d} | num_rays={len(pixels):d} "
                 )
 
-            # if step % 35_000 == 0 and step > 0:
-            #     # evaluation
-            #     radiance_field.eval()
-            #     psnrs = []
-            #     with torch.no_grad():
-            #         for data in tqdm.tqdm(test_dataloader):
-            #             # generate rays from data and the gt pixel color
-            #             rays = namedtuple_map(lambda x: x.to(device), data["rays"])
-            #             pixels = data["pixels"].to(device)
-            #             render_bkgd = data["color_bkgd"].to(device)
-            #             # rendering
-            #             rgb, depth, acc, alive_ray_mask, _, _ = render_image(
-            #                 radiance_field, rays, render_bkgd
-            #             )
-            #             mse = F.mse_loss(rgb, pixels)
-            #             psnr = -10.0 * torch.log(mse) / np.log(10.0)
-            #             psnrs.append(psnr.item())
-            #     psnr_avg = sum(psnrs) / len(psnrs)
-            #     print(f"evaluation: {psnr_avg=}")
+            # if time.time() - tic > 300:
+            if step == 35_000:
+                print("training stops")
+                # evaluation
+                radiance_field.eval()
+                psnrs = []
+                with torch.no_grad():
+                    for data in tqdm.tqdm(test_dataloader):
+                        # generate rays from data and the gt pixel color
+                        rays = namedtuple_map(lambda x: x.to(device), data["rays"])
+                        pixels = data["pixels"].to(device)
+                        render_bkgd = data["color_bkgd"].to(device)
+                        # rendering
+                        rgb, depth, acc, alive_ray_mask, _, _ = render_image(
+                            radiance_field, rays, render_bkgd, render_step_size
+                        )
+                        mse = F.mse_loss(rgb, pixels)
+                        psnr = -10.0 * torch.log(mse) / np.log(10.0)
+                        psnrs.append(psnr.item())
+                psnr_avg = sum(psnrs) / len(psnrs)
+                print(f"evaluation: {psnr_avg=}")
+                exit()
 
             tic_data = time.time()
             step += 1
-
-# "train"
-# elapsed_time=298.27s (data=60.08s) | step=30000 | loss=0.00026
-# evaluation: psnr_avg=33.305334663391115 (6.42 it/s)
-
-# "train" batch_over_images=True
-# elapsed_time=335.21s (data=68.99s) | step=30000 | loss=0.00028
-# evaluation: psnr_avg=33.74970862388611 (6.23 it/s)
-
-# "train" batch_over_images=True, schedule
-# elapsed_time=296.30s (data=54.38s) | step=30000 | loss=0.00022
-# evaluation: psnr_avg=34.3978275680542 (6.22 it/s)
-
-# "trainval"
-# elapsed_time=289.94s (data=51.99s) | step=30000 | loss=0.00021
-# evaluation: psnr_avg=34.44980221748352 (6.61 it/s)
-
-# "trainval" batch_over_images=True, schedule
-# elapsed_time=291.42s (data=52.82s) | step=30000 | loss=0.00020
-# evaluation: psnr_avg=35.41630497932434 (6.40 it/s)
-
-# "trainval" batch_over_images=True, schedule 2**18
-# evaluation: psnr_avg=36.24 (6.75 it/s)
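The loop above now keeps the per-step sample count roughly constant: after each render, the ray batch is rescaled by `TARGET_SAMPLE_BATCH_SIZE / compact_counter`, so the next batch should produce about `1 << 16` compacted samples however sparse the occupancy grid is. A minimal sketch of that controller (the zero guard is an added safety measure, not in the diff):

```python
def next_num_rays(num_rays: int, compact_counter: int, target: int = 1 << 16) -> int:
    # Grow the ray batch when the last render produced fewer samples
    # than the target, shrink it when it produced more.
    return int(num_rays * (target / float(max(compact_counter, 1))))
```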
@@ -14,7 +14,8 @@ inline __device__ int cascaded_grid_idx_at(
     ix = __clamp(ix, 0, resx-1);
     iy = __clamp(iy, 0, resy-1);
     iz = __clamp(iz, 0, resz-1);
-    int idx = ix * resx * resy + iy * resz + iz;
+    int idx = ix * resy * resz + iy * resz + iz;
+    // printf("(ix, iy, iz) = (%d, %d, %d)\n", ix, iy, iz);
     return idx;
 }
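The one-line change above fixes the row-major linearization of the `(ix, iy, iz)` cell index: for a grid of shape `(resx, resy, resz)` the x stride is `resy * resz`, not `resx * resy`. The two agree only when `resx == resz` (e.g. the cubic 128^3 grid used here), which is presumably why the bug was silent. A quick cross-check against NumPy's C-order layout:

```python
import numpy as np

def flat_index(ix: int, iy: int, iz: int, resx: int, resy: int, resz: int) -> int:
    # Row-major (C-order) linearization; resx only bounds ix, it does
    # not appear in the strides.
    return ix * resy * resz + iy * resz + iz

grid = np.arange(4 * 5 * 6).reshape(4, 5, 6)  # deliberately anisotropic
assert grid[2, 3, 4] == flat_index(2, 3, 4, 4, 5, 6)
```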
@@ -89,102 +90,102 @@ __global__ void kernel_raymarching(
 ) {
     CUDA_GET_THREAD_ID(i, n_rays);
 
-    // // locate
-    // rays_o += i * 3;
-    // rays_d += i * 3;
-    // t_min += i;
-    // t_max += i;
-    // const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
-    // const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
-    // const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
-    // const float near = t_min[0], far = t_max[0];
-    // uint32_t ray_idx, base, marching_samples;
-    // uint32_t j;
-    // float t0, t1, t_mid;
-    // // first pass to compute an accurate number of steps
-    // j = 0;
-    // t0 = near;  // TODO(ruilongli): perturb `near` as in ngp_pl?
-    // t1 = t0 + dt;
-    // t_mid = (t0 + t1) * 0.5f;
-    // while (t_mid < far && j < max_per_ray_samples) {
-    //     // current center
-    //     const float x = ox + t_mid * dx;
-    //     const float y = oy + t_mid * dy;
-    //     const float z = oz + t_mid * dz;
-    //     if (grid_occupied_at(x, y, z, resx, resy, resz, aabb, occ_binary)) {
-    //         ++j;
-    //         // march to next sample
-    //         t0 = t1;
-    //         t1 = t0 + dt;
-    //         t_mid = (t0 + t1) * 0.5f;
-    //     }
-    //     else {
-    //         // march to next sample
-    //         t_mid = advance_to_next_voxel(
-    //             t_mid, x, y, z, dx, dy, dz, rdx, rdy, rdz, resx, resy, resz, dt
-    //         );
-    //         t0 = t_mid - dt * 0.5f;
-    //         t1 = t_mid + dt * 0.5f;
-    //     }
-    // }
-    // if (j == 0) return;
-    // marching_samples = j;
-    // base = atomicAdd(steps_counter, marching_samples);
-    // if (base + marching_samples > max_total_samples) return;
-    // ray_idx = atomicAdd(rays_counter, 1);
-    // // locate
-    // frustum_origins += base * 3;
-    // frustum_dirs += base * 3;
-    // frustum_starts += base;
-    // frustum_ends += base;
-    // // Second round
-    // j = 0;
-    // t0 = near;
-    // t1 = t0 + dt;
-    // t_mid = (t0 + t1) / 2.;
-    // while (t_mid < far && j < marching_samples) {
-    //     // current center
-    //     const float x = ox + t_mid * dx;
-    //     const float y = oy + t_mid * dy;
-    //     const float z = oz + t_mid * dz;
-    //     if (grid_occupied_at(x, y, z, resx, resy, resz, aabb, occ_binary)) {
-    //         frustum_origins[j * 3 + 0] = ox;
-    //         frustum_origins[j * 3 + 1] = oy;
-    //         frustum_origins[j * 3 + 2] = oz;
-    //         frustum_dirs[j * 3 + 0] = dx;
-    //         frustum_dirs[j * 3 + 1] = dy;
-    //         frustum_dirs[j * 3 + 2] = dz;
-    //         frustum_starts[j] = t0;
-    //         frustum_ends[j] = t1;
-    //         ++j;
-    //         // march to next sample
-    //         t0 = t1;
-    //         t1 = t0 + dt;
-    //         t_mid = (t0 + t1) * 0.5f;
-    //     }
-    //     else {
-    //         // march to next sample
-    //         t_mid = advance_to_next_voxel(
-    //             t_mid, x, y, z, dx, dy, dz, rdx, rdy, rdz, resx, resy, resz, dt
-    //         );
-    //         t0 = t_mid - dt * 0.5f;
-    //         t1 = t_mid + dt * 0.5f;
-    //     }
-    // }
-    // packed_info[ray_idx * 3 + 0] = i;  // ray idx in {rays_o, rays_d}
-    // packed_info[ray_idx * 3 + 1] = base;  // point idx start.
-    // packed_info[ray_idx * 3 + 2] = j;  // point idx shift (actual marching samples).
+    // locate
+    rays_o += i * 3;
+    rays_d += i * 3;
+    t_min += i;
+    t_max += i;
+    const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
+    const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
+    const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
+    const float near = t_min[0], far = t_max[0];
+    uint32_t ray_idx, base, marching_samples;
+    uint32_t j;
+    float t0, t1, t_mid;
+    // first pass to compute an accurate number of steps
+    j = 0;
+    t0 = near;  // TODO(ruilongli): perturb `near` as in ngp_pl?
+    t1 = t0 + dt;
+    t_mid = (t0 + t1) * 0.5f;
+    while (t_mid < far && j < max_per_ray_samples) {
+        // current center
+        const float x = ox + t_mid * dx;
+        const float y = oy + t_mid * dy;
+        const float z = oz + t_mid * dz;
+        if (grid_occupied_at(x, y, z, resx, resy, resz, aabb, occ_binary)) {
+            ++j;
+            // march to next sample
+            t0 = t1;
+            t1 = t0 + dt;
+            t_mid = (t0 + t1) * 0.5f;
+        }
+        else {
+            // march to next sample
+            t_mid = advance_to_next_voxel(
+                t_mid, x, y, z, dx, dy, dz, rdx, rdy, rdz, resx, resy, resz, dt
+            );
+            t0 = t_mid - dt * 0.5f;
+            t1 = t_mid + dt * 0.5f;
+        }
+    }
+    if (j == 0) return;
+    marching_samples = j;
+    base = atomicAdd(steps_counter, marching_samples);
+    if (base + marching_samples > max_total_samples) return;
+    ray_idx = atomicAdd(rays_counter, 1);
+    // locate
+    frustum_origins += base * 3;
+    frustum_dirs += base * 3;
+    frustum_starts += base;
+    frustum_ends += base;
+    // Second round
+    j = 0;
+    t0 = near;
+    t1 = t0 + dt;
+    t_mid = (t0 + t1) / 2.;
+    while (t_mid < far && j < marching_samples) {
+        // current center
+        const float x = ox + t_mid * dx;
+        const float y = oy + t_mid * dy;
+        const float z = oz + t_mid * dz;
+        if (grid_occupied_at(x, y, z, resx, resy, resz, aabb, occ_binary)) {
+            frustum_origins[j * 3 + 0] = ox;
+            frustum_origins[j * 3 + 1] = oy;
+            frustum_origins[j * 3 + 2] = oz;
+            frustum_dirs[j * 3 + 0] = dx;
+            frustum_dirs[j * 3 + 1] = dy;
+            frustum_dirs[j * 3 + 2] = dz;
+            frustum_starts[j] = t0;
+            frustum_ends[j] = t1;
+            ++j;
+            // march to next sample
+            t0 = t1;
+            t1 = t0 + dt;
+            t_mid = (t0 + t1) * 0.5f;
+        }
+        else {
+            // march to next sample
+            t_mid = advance_to_next_voxel(
+                t_mid, x, y, z, dx, dy, dz, rdx, rdy, rdz, resx, resy, resz, dt
+            );
+            t0 = t_mid - dt * 0.5f;
+            t1 = t_mid + dt * 0.5f;
+        }
+    }
+    packed_info[ray_idx * 3 + 0] = i;  // ray idx in {rays_o, rays_d}
+    packed_info[ray_idx * 3 + 1] = base;  // point idx start.
+    packed_info[ray_idx * 3 + 2] = j;  // point idx shift (actual marching samples).
 
     return;
 }
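The kernel enabled above is the standard two-pass compaction: the first walk only counts samples landing in occupied cells, a single `atomicAdd` on `steps_counter` reserves a contiguous slice `[base, base + count)` of the frustum buffers, and the second walk fills that slice; `packed_info` records `(ray index, base, count)` per surviving ray. A single-threaded Python stand-in for the allocation bookkeeping:

```python
def pack_rays(per_ray_sample_counts):
    # Each ray reserves `count` slots starting at the running total,
    # mimicking the atomicAdd-based allocation in the kernel.
    packed_info, base = [], 0
    for ray_id, count in enumerate(per_ray_sample_counts):
        if count == 0:
            continue  # ray never hit an occupied cell
        packed_info.append((ray_id, base, count))
        base += count
    return packed_info, base  # base == total samples written

# e.g. three rays with 2, 0 and 3 occupied samples
print(pack_rays([2, 0, 3]))  # ([(0, 0, 2), (2, 2, 3)], 5)
```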
@@ -233,62 +234,61 @@ std::vector<torch::Tensor> ray_marching(
     const int max_per_ray_samples,
     const float dt
 ) {
-    // DEVICE_GUARD(rays_o);
-    // CHECK_INPUT(rays_o);
-    // CHECK_INPUT(rays_d);
-    // CHECK_INPUT(t_min);
-    // CHECK_INPUT(t_max);
-    // CHECK_INPUT(aabb);
-    // CHECK_INPUT(occ_binary);
-    // const int n_rays = rays_o.size(0);
-    // // const int threads = 256;
-    // // const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
-    // // helper counter
-    // torch::Tensor steps_counter = torch::zeros(
-    //     {1}, rays_o.options().dtype(torch::kInt32));
-    // torch::Tensor rays_counter = torch::zeros(
-    //     {1}, rays_o.options().dtype(torch::kInt32));
-    // // output frustum samples
-    // torch::Tensor packed_info = torch::zeros(
-    //     {n_rays, 3}, rays_o.options().dtype(torch::kInt32));  // ray_id, sample_id, num_samples
-    // torch::Tensor frustum_origins = torch::zeros({max_total_samples, 3}, rays_o.options());
-    // torch::Tensor frustum_dirs = torch::zeros({max_total_samples, 3}, rays_o.options());
-    // torch::Tensor frustum_starts = torch::zeros({max_total_samples, 1}, rays_o.options());
-    // torch::Tensor frustum_ends = torch::zeros({max_total_samples, 1}, rays_o.options());
-    // kernel_raymarching<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
-    //     // rays
-    //     n_rays,
-    //     rays_o.data_ptr<float>(),
-    //     rays_d.data_ptr<float>(),
-    //     t_min.data_ptr<float>(),
-    //     t_max.data_ptr<float>(),
-    //     // density grid
-    //     aabb.data_ptr<float>(),
-    //     resolution[0].cast<int>(),
-    //     resolution[1].cast<int>(),
-    //     resolution[2].cast<int>(),
-    //     occ_binary.data_ptr<bool>(),
-    //     // sampling
-    //     max_total_samples,
-    //     max_per_ray_samples,
-    //     dt,
-    //     // writable helpers
-    //     steps_counter.data_ptr<int>(),  // total samples.
-    //     rays_counter.data_ptr<int>(),  // total rays.
-    //     packed_info.data_ptr<int>(),
-    //     frustum_origins.data_ptr<float>(),
-    //     frustum_dirs.data_ptr<float>(),
-    //     frustum_starts.data_ptr<float>(),
-    //     frustum_ends.data_ptr<float>()
-    // );
-    // return {packed_info, frustum_origins, frustum_dirs, frustum_starts, frustum_ends, steps_counter};
-    return {};
+    DEVICE_GUARD(rays_o);
+    CHECK_INPUT(rays_o);
+    CHECK_INPUT(rays_d);
+    CHECK_INPUT(t_min);
+    CHECK_INPUT(t_max);
+    CHECK_INPUT(aabb);
+    CHECK_INPUT(occ_binary);
+    const int n_rays = rays_o.size(0);
+    const int threads = 256;
+    const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
+    // helper counter
+    torch::Tensor steps_counter = torch::zeros(
+        {1}, rays_o.options().dtype(torch::kInt32));
+    torch::Tensor rays_counter = torch::zeros(
+        {1}, rays_o.options().dtype(torch::kInt32));
+    // output frustum samples
+    torch::Tensor packed_info = torch::zeros(
+        {n_rays, 3}, rays_o.options().dtype(torch::kInt32));  // ray_id, sample_id, num_samples
+    torch::Tensor frustum_origins = torch::zeros({max_total_samples, 3}, rays_o.options());
+    torch::Tensor frustum_dirs = torch::zeros({max_total_samples, 3}, rays_o.options());
+    torch::Tensor frustum_starts = torch::zeros({max_total_samples, 1}, rays_o.options());
+    torch::Tensor frustum_ends = torch::zeros({max_total_samples, 1}, rays_o.options());
+    kernel_raymarching<<<blocks, threads>>>(
+        // rays
+        n_rays,
+        rays_o.data_ptr<float>(),
+        rays_d.data_ptr<float>(),
+        t_min.data_ptr<float>(),
+        t_max.data_ptr<float>(),
+        // density grid
+        aabb.data_ptr<float>(),
+        resolution[0].cast<int>(),
+        resolution[1].cast<int>(),
+        resolution[2].cast<int>(),
+        occ_binary.data_ptr<bool>(),
+        // sampling
+        max_total_samples,
+        max_per_ray_samples,
+        dt,
+        // writable helpers
+        steps_counter.data_ptr<int>(),  // total samples.
+        rays_counter.data_ptr<int>(),  // total rays.
+        packed_info.data_ptr<int>(),
+        frustum_origins.data_ptr<float>(),
+        frustum_dirs.data_ptr<float>(),
+        frustum_starts.data_ptr<float>(),
+        frustum_ends.data_ptr<float>()
+    );
+    return {packed_info, frustum_origins, frustum_dirs, frustum_starts, frustum_ends, steps_counter};
 }
@@ -72,6 +72,7 @@ class OccupancyField(nn.Module):
         self.register_buffer("aabb", aabb)
         self.resolution = resolution
+        self.register_buffer("resolution_tensor", torch.tensor(resolution))
         self.num_dim = num_dim
         self.num_cells = torch.tensor(resolution).prod().item()
 
@@ -107,7 +108,6 @@ class OccupancyField(nn.Module):
         if n < len(occupied_indices):
             selector = torch.randint(len(occupied_indices), (n,), device=device)
             occupied_indices = occupied_indices[selector]
-
         indices = torch.cat([uniform_indices, occupied_indices], dim=0)
         return indices
@@ -129,19 +129,19 @@ class OccupancyField(nn.Module):
         stage we change the sampling strategy to 1/4 uniformly sampled cells
         together with 1/4 occupied cells.
         """
-        resolution = torch.tensor(self.resolution).to(self.occ_grid.device)
-
         # sample cells
         if step < warmup_steps:
             indices = self._get_all_cells()
         else:
-            N = resolution.prod().item() // 4
+            N = self.num_cells // 4
             indices = self._sample_uniform_and_occupied_cells(N)
 
         # infer occupancy: density * step_size
         tmp_occ_grid = -torch.ones_like(self.occ_grid)
         grid_coords = self.grid_coords[indices]
-        x = (grid_coords + torch.rand_like(grid_coords.float())) / resolution
+        x = (
+            grid_coords + torch.rand_like(grid_coords.float())
+        ) / self.resolution_tensor
         bb_min, bb_max = torch.split(self.aabb, [self.num_dim, self.num_dim], dim=0)
         x = x * (bb_max - bb_min) + bb_min
         tmp_occ_grid[indices] = self.occ_eval_fn(x).squeeze(-1)
@@ -152,8 +152,8 @@ class OccupancyField(nn.Module):
             self.occ_grid[ema_mask] * ema_decay, tmp_occ_grid[ema_mask]
         )
         self.occ_grid_mean = self.occ_grid.mean()
-        self.occ_grid_binary = self.occ_grid > min(
-            self.occ_grid_mean.item(), occ_threshold
+        self.occ_grid_binary = self.occ_grid > torch.clamp(
+            self.occ_grid_mean, max=occ_threshold
         )
 
     @torch.no_grad()
...
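Two of the occupancy-field changes above share a goal: avoid per-update CPU<->GPU round trips. Registering `resolution_tensor` as a buffer avoids rebuilding the tensor (and moving it to the device) on every call, and `torch.clamp(mean, max=occ_threshold)` replaces `min(mean.item(), occ_threshold)` so the binarization threshold stays on the GPU. A small equivalence check for the latter:

```python
import torch

occ_grid = torch.rand(128 ** 3)
occ_threshold = 0.01

mean = occ_grid.mean()
thresh_old = min(mean.item(), occ_threshold)       # syncs to the CPU
thresh_new = torch.clamp(mean, max=occ_threshold)  # stays on the device
assert torch.isclose(thresh_new, torch.tensor(thresh_old))
```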
@@ -16,6 +16,7 @@ def volumetric_rendering(
     render_bkgd: torch.Tensor = None,
     render_n_samples: int = 1024,
     render_est_n_samples: int = None,
+    render_step_size: float = None,
     **kwargs,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """A *fast* version of differentiable volumetric rendering."""
@@ -23,8 +24,6 @@ def volumetric_rendering(
     if render_bkgd is None:
         render_bkgd = torch.ones(3, device=device)
 
-    # scene_resolution = torch.tensor(scene_resolution, dtype=torch.int, device=device)
-
     rays_o = rays_o.contiguous()
     rays_d = rays_d.contiguous()
     scene_aabb = scene_aabb.contiguous()
@@ -36,22 +35,22 @@ def volumetric_rendering(
         render_total_samples = n_rays * render_n_samples
     else:
         render_total_samples = render_est_n_samples
-    render_step_size = (
-        (scene_aabb[3:] - scene_aabb[:3]).max() * math.sqrt(3) / render_n_samples
-    )
+    if render_step_size is None:
+        # Note: CPU<->GPU is not ideal, try to pre-define it outside this function.
+        render_step_size = (
+            (scene_aabb[3:] - scene_aabb[:3]).max() * math.sqrt(3) / render_n_samples
+        )
 
     with torch.no_grad():
         t_min, t_max = ray_aabb_intersect(rays_o, rays_d, scene_aabb)
-        # t_min = torch.clamp(t_min, max=1e10)
-        # t_max = torch.clamp(t_max, max=1e10)
         (
-            # packed_info,
-            # frustum_origins,
-            # frustum_dirs,
-            # frustum_starts,
-            # frustum_ends,
-            # steps_counter,
+            packed_info,
+            frustum_origins,
+            frustum_dirs,
+            frustum_starts,
+            frustum_ends,
+            steps_counter,
         ) = ray_marching(
             # rays
             rays_o,
@@ -68,43 +67,41 @@ def volumetric_rendering(
            render_step_size,
        )
 
-        # # squeeze valid samples
-        # total_samples = max(packed_info[:, -1].sum(), 1)
-        # total_samples = int(math.ceil(total_samples / 128.0)) * 128
-        # frustum_origins = frustum_origins[:total_samples]
-        # frustum_dirs = frustum_dirs[:total_samples]
-        # frustum_starts = frustum_starts[:total_samples]
-        # frustum_ends = frustum_ends[:total_samples]
-        # frustum_positions = (
-        #     frustum_origins + frustum_dirs * (frustum_starts + frustum_ends) / 2.0
-        # )
+        # squeeze valid samples
+        total_samples = max(packed_info[:, -1].sum(), 1)
+        frustum_origins = frustum_origins[:total_samples]
+        frustum_dirs = frustum_dirs[:total_samples]
+        frustum_starts = frustum_starts[:total_samples]
+        frustum_ends = frustum_ends[:total_samples]
+        frustum_positions = (
+            frustum_origins + frustum_dirs * (frustum_starts + frustum_ends) / 2.0
+        )
 
-    # query_results = query_fn(frustum_positions, frustum_dirs, **kwargs)
-    # rgbs, densities = query_results[0], query_results[1]
+    query_results = query_fn(frustum_positions, frustum_dirs, **kwargs)
+    rgbs, densities = query_results[0], query_results[1]
 
-    # (
-    #     accumulated_weight,
-    #     accumulated_depth,
-    #     accumulated_color,
-    #     alive_ray_mask,
-    #     compact_steps_counter,
-    # ) = VolumeRenderer.apply(
-    #     packed_info,
-    #     frustum_starts,
-    #     frustum_ends,
-    #     densities.contiguous(),
-    #     rgbs.contiguous(),
-    # )
+    (
+        accumulated_weight,
+        accumulated_depth,
+        accumulated_color,
+        alive_ray_mask,
+        compact_steps_counter,
+    ) = VolumeRenderer.apply(
+        packed_info,
+        frustum_starts,
+        frustum_ends,
+        densities.contiguous(),
+        rgbs.contiguous(),
+    )
 
-    # accumulated_depth = torch.clip(accumulated_depth, t_min[:, None], t_max[:, None])
-    # accumulated_color = accumulated_color + render_bkgd * (1.0 - accumulated_weight)
+    accumulated_depth = torch.clip(accumulated_depth, t_min[:, None], t_max[:, None])
+    accumulated_color = accumulated_color + render_bkgd * (1.0 - accumulated_weight)
 
-    # return (
-    #     accumulated_color,
-    #     accumulated_depth,
-    #     accumulated_weight,
-    #     alive_ray_mask,
-    #     steps_counter,
-    #     compact_steps_counter,
-    # )
+    return (
+        accumulated_color,
+        accumulated_depth,
+        accumulated_weight,
+        alive_ray_mask,
+        steps_counter,
+        compact_steps_counter,
+    )
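On the new `render_step_size` argument above: deriving the step size from the GPU-resident `scene_aabb` inside `volumetric_rendering` forces a device sync on every call, which is why `trainval.py` now computes it once with `.item()` and passes it in. A sketch of the precomputation, following the formula in the diff:

```python
import math
import torch

def precompute_step_size(scene_aabb: torch.Tensor, n_samples: int = 1024) -> float:
    # Longest AABB edge times sqrt(3) (the cube diagonal), divided into
    # n_samples steps; .item() pays the CPU<->GPU sync once, up front.
    return (
        (scene_aabb[3:] - scene_aabb[:3]).max() * math.sqrt(3) / n_samples
    ).item()
```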