Commit 16324602 authored by Ruilong Li

benchmark

parent 96211bba
......@@ -10,6 +10,12 @@ python examples/trainval.py
## Performance Reference
| trainval (35k steps, batch of 1<<16 samples) | Lego | Mic | Materials |
| - | - | - | - |
| Time | 377s | 357s | 354s |
| PSNR | 36.08 | 36.58 | 29.63 |

Tested with the default settings on the Lego test set.
| Model | Split | PSNR | Train Time | Test Speed | GPU | Train Memory |
......
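For reference, the PSNR figures above are computed from the test-set MSE exactly as in the evaluation loop of `examples/trainval.py` below; a minimal sketch:

```python
import torch

def psnr_from_mse(mse: torch.Tensor) -> float:
    # PSNR in dB for images in [0, 1]: -10 * log10(MSE).
    return (-10.0 * torch.log10(mse)).item()

print(psnr_from_mse(torch.tensor(2.5e-4)))  # ~36.02 dB, i.e. Lego-level quality
```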
......@@ -187,9 +187,7 @@ class SubjectLoader(torch.utils.data.Dataset):
camera_dirs = F.pad(
torch.stack(
[
(x - self.K[0, 2] + 0.5)
/ self.K[0, 0]
* (-1.0 if self.OPENGL_CAMERA else 1.0),
(x - self.K[0, 2] + 0.5) / self.K[0, 0],
(y - self.K[1, 2] + 0.5)
/ self.K[1, 1]
* (-1.0 if self.OPENGL_CAMERA else 1.0),
......@@ -197,7 +195,7 @@ class SubjectLoader(torch.utils.data.Dataset):
dim=-1,
),
(0, 1),
value=1,
value=(-1.0 if self.OPENGL_CAMERA else 1.0),
) # [num_rays, 3]
# [n_cams, height, width, 3]
......
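The hunk above moves the OpenGL-convention sign flip off the x axis: only the y component and the padded homogeneous component are negated for OpenGL-style cameras. A standalone sketch of the same pixel-to-direction math (hypothetical helper, mirroring the loader code):

```python
import torch
import torch.nn.functional as F

def camera_dirs(x: torch.Tensor, y: torch.Tensor, K: torch.Tensor, opengl: bool):
    # Unproject pixel centers through the intrinsics K; OpenGL-style cameras
    # look down -z with y up, so y and the appended z component flip sign.
    sign = -1.0 if opengl else 1.0
    dirs = torch.stack(
        [
            (x - K[0, 2] + 0.5) / K[0, 0],
            (y - K[1, 2] + 0.5) / K[1, 1] * sign,
        ],
        dim=-1,
    )
    return F.pad(dirs, (0, 1), value=sign)  # [num_rays, 3]
```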
......@@ -98,7 +98,6 @@ class NGPradianceField(BaseRadianceField):
},
)
@torch.cuda.amp.autocast()
def query_density(self, x, return_feat: bool = False):
bb_min, bb_max = torch.split(self.aabb, [self.num_dim, self.num_dim], dim=0)
x = (x - bb_min) / (bb_max - bb_min)
......@@ -119,7 +118,6 @@ class NGPradianceField(BaseRadianceField):
else:
return density
@torch.cuda.amp.autocast()
def _query_rgb(self, dir, embedding):
# tcnn requires directions in the range [0, 1]
if self.use_viewdirs:
......@@ -131,7 +129,6 @@ class NGPradianceField(BaseRadianceField):
rgb = self.mlp_head(h).view(list(embedding.shape[:-1]) + [3]).to(embedding)
return rgb
@torch.cuda.amp.autocast()
def forward(
self,
positions: torch.Tensor,
......
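This hunk drops the `@torch.cuda.amp.autocast()` decorators from `query_density`, `_query_rgb`, and `forward`. If mixed precision is still desired, one option (an assumption on my part; the commit does not show where autocast goes, if anywhere) is to scope it at the call site:

```python
import torch

# Hypothetical call-site autocast instead of per-method decorators.
with torch.cuda.amp.autocast():
    density = radiance_field.query_density(x)
```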
......@@ -5,13 +5,15 @@ import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from datasets.nerf_synthetic import Rays, SubjectLoader, namedtuple_map
from datasets.nerf_synthetic import SubjectLoader, namedtuple_map
from radiance_fields.ngp import NGPradianceField
from nerfacc import OccupancyField, volumetric_rendering
TARGET_SAMPLE_BATCH_SIZE = 1 << 16
def render_image(radiance_field, rays, render_bkgd):
def render_image(radiance_field, rays, render_bkgd, render_step_size):
"""Render the pixels of an image.
Args:
......@@ -32,7 +34,9 @@ def render_image(radiance_field, rays, render_bkgd):
num_rays, _ = rays_shape
results = []
chunk = torch.iinfo(torch.int32).max if radiance_field.training else 81920
render_est_n_samples = 2**16 * 16 if radiance_field.training else None
render_est_n_samples = (
TARGET_SAMPLE_BATCH_SIZE * 16 if radiance_field.training else None
)
for i in range(0, num_rays, chunk):
chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays)
chunk_results = volumetric_rendering(
......@@ -45,6 +49,7 @@ def render_image(radiance_field, rays, render_bkgd):
render_bkgd=render_bkgd,
render_n_samples=render_n_samples,
render_est_n_samples=render_est_n_samples, # memory control: worst case
render_step_size=render_step_size,
)
results.append(chunk_results)
rgb, depth, acc, alive_ray_mask, counter, compact_counter = [
......@@ -64,13 +69,14 @@ if __name__ == "__main__":
torch.manual_seed(42)
device = "cuda:0"
scene = "lego"
# setup dataset
train_dataset = SubjectLoader(
subject_id="mic",
subject_id=scene,
root_fp="/home/ruilongli/data/nerf_synthetic/",
split="trainval",
num_rays=409600,
num_rays=4096,
)
train_dataset.images = train_dataset.images.to(device)
......@@ -85,7 +91,7 @@ if __name__ == "__main__":
)
test_dataset = SubjectLoader(
subject_id="mic",
subject_id=scene,
root_fp="/home/ruilongli/data/nerf_synthetic/",
split="test",
num_rays=None,
......@@ -112,7 +118,7 @@ if __name__ == "__main__":
render_n_samples = 1024
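# Pick the step size so that ~render_n_samples steps span the scene AABB
# diagonal; taking .item() once here yields a plain Python float and avoids
# a GPU->CPU sync inside every subsequent render call.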
render_step_size = (
(scene_aabb[3:] - scene_aabb[:3]).max() * math.sqrt(3) / render_n_samples
)
).item()
optimizer = torch.optim.Adam(
radiance_field.parameters(),
......@@ -144,123 +150,75 @@ if __name__ == "__main__":
occ_eval_fn=occ_eval_fn, aabb=scene_aabb, resolution=128
).to(device)
render_bkgd = torch.ones(3, device=device)
# training
step = 0
tic = time.time()
data_time = 0
tic_data = time.time()
weights_image_ids = torch.ones((len(train_dataset.images),), device=device)
weights_xs = torch.ones(
(train_dataset.WIDTH,),
device=device,
)
weights_ys = torch.ones(
(train_dataset.HEIGHT,),
device=device,
)
for epoch in range(40000000):
data = train_dataset[0]
for epoch in range(10000000):
for i in range(len(train_dataset)):
data = train_dataset[i]
data_time += time.time() - tic_data
if step > 35_000:
print("training stops")
exit()
# generate rays from data and the gt pixel color
rays = namedtuple_map(lambda x: x.to(device), data["rays"])
pixels = data["pixels"].to(device)
render_bkgd = data["color_bkgd"].to(device)
# rays = namedtuple_map(lambda x: x.to(device), data["rays"])
# pixels = data["pixels"].to(device)
render_bkgd = data["color_bkgd"]
rays = data["rays"]
pixels = data["pixels"]
# # update occupancy grid
# occ_field.every_n_step(step)
# update occupancy grid
occ_field.every_n_step(step)
render_est_n_samples = 2**16 * 16 if radiance_field.training else None
volumetric_rendering(
query_fn=radiance_field.forward, # {x, dir} -> {rgb, density}
rays_o=rays.origins,
rays_d=rays.viewdirs,
scene_aabb=occ_field.aabb,
scene_occ_binary=occ_field.occ_grid_binary,
scene_resolution=occ_field.resolution,
render_bkgd=render_bkgd,
render_n_samples=render_n_samples,
render_est_n_samples=render_est_n_samples, # memory control: worst case
rgb, depth, acc, alive_ray_mask, counter, compact_counter = render_image(
radiance_field, rays, render_bkgd, render_step_size
)
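# Adapt the ray batch so each step produces roughly TARGET_SAMPLE_BATCH_SIZE
# compacted samples: scale the ray count by target / actually-produced samples.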
num_rays = len(pixels)
num_rays = int(
num_rays * (TARGET_SAMPLE_BATCH_SIZE / float(compact_counter.item()))
)
train_dataset.update_num_rays(num_rays)
# rgb, depth, acc, alive_ray_mask, counter, compact_counter = render_image(
# radiance_field, rays, render_bkgd
# )
# num_rays = len(pixels)
# num_rays = int(num_rays * (2**16 / float(compact_counter)))
# num_rays = int(math.ceil(num_rays / 128.0) * 128)
# train_dataset.update_num_rays(num_rays)
# # compute loss
# loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
# compute loss
loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
# optimizer.zero_grad()
# (loss * 128.0).backward()
# optimizer.step()
# scheduler.step()
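# The fixed 128x loss scaling below presumably keeps gradient magnitudes
# healthy (an assumption; the commit itself does not explain the constant).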
optimizer.zero_grad()
(loss * 128).backward()
optimizer.step()
scheduler.step()
if step % 50 == 0:
if step % 100 == 0:
elapsed_time = time.time() - tic
print(
f"elapsed_time={elapsed_time:.2f}s (data={data_time:.2f}s) | {step=} | "
# f"loss={loss:.5f} | "
# f"alive_ray_mask={alive_ray_mask.long().sum():d} | "
# f"counter={counter:d} | compact_counter={compact_counter:d} | num_rays={len(pixels):d} "
f"loss={loss:.5f} | "
f"alive_ray_mask={alive_ray_mask.long().sum():d} | "
f"counter={counter.item():d} | compact_counter={compact_counter.item():d} | num_rays={len(pixels):d} "
)
# if step % 35_000 == 0 and step > 0:
# # evaluation
# radiance_field.eval()
# psnrs = []
# with torch.no_grad():
# for data in tqdm.tqdm(test_dataloader):
# # generate rays from data and the gt pixel color
# rays = namedtuple_map(lambda x: x.to(device), data["rays"])
# pixels = data["pixels"].to(device)
# render_bkgd = data["color_bkgd"].to(device)
# # rendering
# rgb, depth, acc, alive_ray_mask, _, _ = render_image(
# radiance_field, rays, render_bkgd
# )
# mse = F.mse_loss(rgb, pixels)
# psnr = -10.0 * torch.log(mse) / np.log(10.0)
# psnrs.append(psnr.item())
# psnr_avg = sum(psnrs) / len(psnrs)
# print(f"evaluation: {psnr_avg=}")
# if time.time() - tic > 300:
if step == 35_000:
print("training stops")
# evaluation
radiance_field.eval()
psnrs = []
with torch.no_grad():
for data in tqdm.tqdm(test_dataloader):
# generate rays from data and the gt pixel color
rays = namedtuple_map(lambda x: x.to(device), data["rays"])
pixels = data["pixels"].to(device)
render_bkgd = data["color_bkgd"].to(device)
# rendering
rgb, depth, acc, alive_ray_mask, _, _ = render_image(
radiance_field, rays, render_bkgd, render_step_size
)
mse = F.mse_loss(rgb, pixels)
psnr = -10.0 * torch.log(mse) / np.log(10.0)
psnrs.append(psnr.item())
psnr_avg = sum(psnrs) / len(psnrs)
print(f"evaluation: {psnr_avg=}")
exit()
tic_data = time.time()
step += 1
# "train"
# elapsed_time=298.27s (data=60.08s) | step=30000 | loss=0.00026
# evaluation: psnr_avg=33.305334663391115 (6.42 it/s)
# "train" batch_over_images=True
# elapsed_time=335.21s (data=68.99s) | step=30000 | loss=0.00028
# evaluation: psnr_avg=33.74970862388611 (6.23 it/s)
# "train" batch_over_images=True, schedule
# elapsed_time=296.30s (data=54.38s) | step=30000 | loss=0.00022
# evaluation: psnr_avg=34.3978275680542 (6.22 it/s)
# "trainval"
# elapsed_time=289.94s (data=51.99s) | step=30000 | loss=0.00021
# evaluation: psnr_avg=34.44980221748352 (6.61 it/s)
# "trainval" batch_over_images=True, schedule
# elapsed_time=291.42s (data=52.82s) | step=30000 | loss=0.00020
# evaluation: psnr_avg=35.41630497932434 (6.40 it/s)
# "trainval" batch_over_images=True, schedule 2**18
# evaluation: psnr_avg=36.24 (6.75 it/s)
......@@ -14,7 +14,8 @@ inline __device__ int cascaded_grid_idx_at(
ix = __clamp(ix, 0, resx-1);
iy = __clamp(iy, 0, resy-1);
iz = __clamp(iz, 0, resz-1);
int idx = ix * resx * resy + iy * resz + iz;
int idx = ix * resy * resz + iy * resz + iz;
// printf("(ix, iy, iz) = (%d, %d, %d)\n", ix, iy, iz);
return idx;
}
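The one-line fix above corrects the flattened cell index for a (resx, resy, resz) row-major grid: the x stride must be resy * resz, not resx * resy. Equivalently, idx = (ix * resy + iy) * resz + iz. A quick sanity check of the corrected formula (a sketch, not the kernel itself):

```python
# Row-major strides for a (resx, resy, resz) grid: x stride = resy * resz,
# y stride = resz, z stride = 1.
def grid_idx(ix: int, iy: int, iz: int, resx: int, resy: int, resz: int) -> int:
    return ix * resy * resz + iy * resz + iz

assert grid_idx(1, 0, 0, 4, 8, 16) == 8 * 16  # stepping ix skips one full y-z slab
```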
......@@ -89,102 +90,102 @@ __global__ void kernel_raymarching(
) {
CUDA_GET_THREAD_ID(i, n_rays);
// // locate
// rays_o += i * 3;
// rays_d += i * 3;
// t_min += i;
// t_max += i;
// locate
rays_o += i * 3;
rays_d += i * 3;
t_min += i;
t_max += i;
// const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
// const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
// const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
// const float near = t_min[0], far = t_max[0];
const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
const float near = t_min[0], far = t_max[0];
// uint32_t ray_idx, base, marching_samples;
// uint32_t j;
// float t0, t1, t_mid;
uint32_t ray_idx, base, marching_samples;
uint32_t j;
float t0, t1, t_mid;
// // first pass to compute an accurate number of steps
// j = 0;
// t0 = near; // TODO(ruilongli): perturb `near` as in ngp_pl?
// t1 = t0 + dt;
// t_mid = (t0 + t1) * 0.5f;
// first pass to compute an accurate number of steps
j = 0;
t0 = near; // TODO(ruilongli): perturb `near` as in ngp_pl?
t1 = t0 + dt;
t_mid = (t0 + t1) * 0.5f;
// while (t_mid < far && j < max_per_ray_samples) {
// // current center
// const float x = ox + t_mid * dx;
// const float y = oy + t_mid * dy;
// const float z = oz + t_mid * dz;
while (t_mid < far && j < max_per_ray_samples) {
// current center
const float x = ox + t_mid * dx;
const float y = oy + t_mid * dy;
const float z = oz + t_mid * dz;
// if (grid_occupied_at(x, y, z, resx, resy, resz, aabb, occ_binary)) {
// ++j;
// // march to next sample
// t0 = t1;
// t1 = t0 + dt;
// t_mid = (t0 + t1) * 0.5f;
// }
// else {
// // march to next sample
// t_mid = advance_to_next_voxel(
// t_mid, x, y, z, dx, dy, dz, rdx, rdy, rdz, resx, resy, resz, dt
// );
// t0 = t_mid - dt * 0.5f;
// t1 = t_mid + dt * 0.5f;
// }
// }
// if (j == 0) return;
if (grid_occupied_at(x, y, z, resx, resy, resz, aabb, occ_binary)) {
++j;
// march to next sample
t0 = t1;
t1 = t0 + dt;
t_mid = (t0 + t1) * 0.5f;
}
else {
// march to next sample
t_mid = advance_to_next_voxel(
t_mid, x, y, z, dx, dy, dz, rdx, rdy, rdz, resx, resy, resz, dt
);
t0 = t_mid - dt * 0.5f;
t1 = t_mid + dt * 0.5f;
}
}
if (j == 0) return;
// marching_samples = j;
// base = atomicAdd(steps_counter, marching_samples);
// if (base + marching_samples > max_total_samples) return;
// ray_idx = atomicAdd(rays_counter, 1);
marching_samples = j;
base = atomicAdd(steps_counter, marching_samples);
if (base + marching_samples > max_total_samples) return;
ray_idx = atomicAdd(rays_counter, 1);
// // locate
// frustum_origins += base * 3;
// frustum_dirs += base * 3;
// frustum_starts += base;
// frustum_ends += base;
// locate
frustum_origins += base * 3;
frustum_dirs += base * 3;
frustum_starts += base;
frustum_ends += base;
// // Second round
// j = 0;
// t0 = near;
// t1 = t0 + dt;
// t_mid = (t0 + t1) / 2.;
// Second round
j = 0;
t0 = near;
t1 = t0 + dt;
t_mid = (t0 + t1) / 2.;
// while (t_mid < far && j < marching_samples) {
// // current center
// const float x = ox + t_mid * dx;
// const float y = oy + t_mid * dy;
// const float z = oz + t_mid * dz;
while (t_mid < far && j < marching_samples) {
// current center
const float x = ox + t_mid * dx;
const float y = oy + t_mid * dy;
const float z = oz + t_mid * dz;
// if (grid_occupied_at(x, y, z, resx, resy, resz, aabb, occ_binary)) {
// frustum_origins[j * 3 + 0] = ox;
// frustum_origins[j * 3 + 1] = oy;
// frustum_origins[j * 3 + 2] = oz;
// frustum_dirs[j * 3 + 0] = dx;
// frustum_dirs[j * 3 + 1] = dy;
// frustum_dirs[j * 3 + 2] = dz;
// frustum_starts[j] = t0;
// frustum_ends[j] = t1;
// ++j;
// // march to next sample
// t0 = t1;
// t1 = t0 + dt;
// t_mid = (t0 + t1) * 0.5f;
// }
// else {
// // march to next sample
// t_mid = advance_to_next_voxel(
// t_mid, x, y, z, dx, dy, dz, rdx, rdy, rdz, resx, resy, resz, dt
// );
// t0 = t_mid - dt * 0.5f;
// t1 = t_mid + dt * 0.5f;
// }
// }
if (grid_occupied_at(x, y, z, resx, resy, resz, aabb, occ_binary)) {
frustum_origins[j * 3 + 0] = ox;
frustum_origins[j * 3 + 1] = oy;
frustum_origins[j * 3 + 2] = oz;
frustum_dirs[j * 3 + 0] = dx;
frustum_dirs[j * 3 + 1] = dy;
frustum_dirs[j * 3 + 2] = dz;
frustum_starts[j] = t0;
frustum_ends[j] = t1;
++j;
// march to next sample
t0 = t1;
t1 = t0 + dt;
t_mid = (t0 + t1) * 0.5f;
}
else {
// march to next sample
t_mid = advance_to_next_voxel(
t_mid, x, y, z, dx, dy, dz, rdx, rdy, rdz, resx, resy, resz, dt
);
t0 = t_mid - dt * 0.5f;
t1 = t_mid + dt * 0.5f;
}
}
// packed_info[ray_idx * 3 + 0] = i; // ray idx in {rays_o, rays_d}
// packed_info[ray_idx * 3 + 1] = base; // point idx start.
// packed_info[ray_idx * 3 + 2] = j; // point idx shift (actual marching samples).
packed_info[ray_idx * 3 + 0] = i; // ray idx in {rays_o, rays_d}
packed_info[ray_idx * 3 + 1] = base; // point idx start.
packed_info[ray_idx * 3 + 2] = j; // point idx shift (actual marching samples).
return;
}
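In outline, the kernel marches each ray twice: the first pass counts occupied samples, an `atomicAdd` on `steps_counter` reserves a contiguous slot in the output buffers, and the second pass writes the frustum samples into that slot. `packed_info` then stores (ray index, base offset, sample count) per surviving ray. A sketch of how a consumer might unpack one ray (hypothetical helper, assuming that layout):

```python
import torch

def unpack_ray(packed_info: torch.Tensor, frustum_starts: torch.Tensor, k: int):
    # packed_info: [n_alive_rays, 3] rows of (ray idx, base offset, num samples).
    ray_idx, base, num = packed_info[k].tolist()
    return ray_idx, frustum_starts[base : base + num]  # the k-th ray's samples
```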
......@@ -233,62 +234,61 @@ std::vector<torch::Tensor> ray_marching(
const int max_per_ray_samples,
const float dt
) {
// DEVICE_GUARD(rays_o);
DEVICE_GUARD(rays_o);
// CHECK_INPUT(rays_o);
// CHECK_INPUT(rays_d);
// CHECK_INPUT(t_min);
// CHECK_INPUT(t_max);
// CHECK_INPUT(aabb);
// CHECK_INPUT(occ_binary);
CHECK_INPUT(rays_o);
CHECK_INPUT(rays_d);
CHECK_INPUT(t_min);
CHECK_INPUT(t_max);
CHECK_INPUT(aabb);
CHECK_INPUT(occ_binary);
// const int n_rays = rays_o.size(0);
const int n_rays = rays_o.size(0);
// // const int threads = 256;
// // const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// // helper counter
// torch::Tensor steps_counter = torch::zeros(
// {1}, rays_o.options().dtype(torch::kInt32));
// torch::Tensor rays_counter = torch::zeros(
// {1}, rays_o.options().dtype(torch::kInt32));
// helper counter
torch::Tensor steps_counter = torch::zeros(
{1}, rays_o.options().dtype(torch::kInt32));
torch::Tensor rays_counter = torch::zeros(
{1}, rays_o.options().dtype(torch::kInt32));
// // output frustum samples
// torch::Tensor packed_info = torch::zeros(
// {n_rays, 3}, rays_o.options().dtype(torch::kInt32)); // ray_id, sample_id, num_samples
// torch::Tensor frustum_origins = torch::zeros({max_total_samples, 3}, rays_o.options());
// torch::Tensor frustum_dirs = torch::zeros({max_total_samples, 3}, rays_o.options());
// torch::Tensor frustum_starts = torch::zeros({max_total_samples, 1}, rays_o.options());
// torch::Tensor frustum_ends = torch::zeros({max_total_samples, 1}, rays_o.options());
// output frustum samples
torch::Tensor packed_info = torch::zeros(
{n_rays, 3}, rays_o.options().dtype(torch::kInt32)); // ray_id, sample_id, num_samples
torch::Tensor frustum_origins = torch::zeros({max_total_samples, 3}, rays_o.options());
torch::Tensor frustum_dirs = torch::zeros({max_total_samples, 3}, rays_o.options());
torch::Tensor frustum_starts = torch::zeros({max_total_samples, 1}, rays_o.options());
torch::Tensor frustum_ends = torch::zeros({max_total_samples, 1}, rays_o.options());
// kernel_raymarching<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
// // rays
// n_rays,
// rays_o.data_ptr<float>(),
// rays_d.data_ptr<float>(),
// t_min.data_ptr<float>(),
// t_max.data_ptr<float>(),
// // density grid
// aabb.data_ptr<float>(),
// resolution[0].cast<int>(),
// resolution[1].cast<int>(),
// resolution[2].cast<int>(),
// occ_binary.data_ptr<bool>(),
// // sampling
// max_total_samples,
// max_per_ray_samples,
// dt,
// // writable helpers
// steps_counter.data_ptr<int>(), // total samples.
// rays_counter.data_ptr<int>(), // total rays.
// packed_info.data_ptr<int>(),
// frustum_origins.data_ptr<float>(),
// frustum_dirs.data_ptr<float>(),
// frustum_starts.data_ptr<float>(),
// frustum_ends.data_ptr<float>()
// );
kernel_raymarching<<<blocks, threads>>>(
// rays
n_rays,
rays_o.data_ptr<float>(),
rays_d.data_ptr<float>(),
t_min.data_ptr<float>(),
t_max.data_ptr<float>(),
// density grid
aabb.data_ptr<float>(),
resolution[0].cast<int>(),
resolution[1].cast<int>(),
resolution[2].cast<int>(),
occ_binary.data_ptr<bool>(),
// sampling
max_total_samples,
max_per_ray_samples,
dt,
// writable helpers
steps_counter.data_ptr<int>(), // total samples.
rays_counter.data_ptr<int>(), // total rays.
packed_info.data_ptr<int>(),
frustum_origins.data_ptr<float>(),
frustum_dirs.data_ptr<float>(),
frustum_starts.data_ptr<float>(),
frustum_ends.data_ptr<float>()
);
// return {packed_info, frustum_origins, frustum_dirs, frustum_starts, frustum_ends, steps_counter};
return {};
return {packed_info, frustum_origins, frustum_dirs, frustum_starts, frustum_ends, steps_counter};
}
......@@ -72,6 +72,7 @@ class OccupancyField(nn.Module):
self.register_buffer("aabb", aabb)
self.resolution = resolution
self.register_buffer("resolution_tensor", torch.tensor(resolution))
self.num_dim = num_dim
self.num_cells = torch.tensor(resolution).prod().item()
......@@ -107,7 +108,6 @@ class OccupancyField(nn.Module):
if n < len(occupied_indices):
selector = torch.randint(len(occupied_indices), (n,), device=device)
occupied_indices = occupied_indices[selector]
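# Cap the occupied-cell draw at n so the batch mixes (up to) n uniformly
# sampled cells with (up to) n currently-occupied cells.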
indices = torch.cat([uniform_indices, occupied_indices], dim=0)
return indices
......@@ -129,19 +129,19 @@ class OccupancyField(nn.Module):
stage we change the sampling strategy to 1/4 uniformly sampled cells
together with 1/4 occupied cells.
"""
resolution = torch.tensor(self.resolution).to(self.occ_grid.device)
# sample cells
if step < warmup_steps:
indices = self._get_all_cells()
else:
N = resolution.prod().item() // 4
N = self.num_cells // 4
indices = self._sample_uniform_and_occupied_cells(N)
# infer occupancy: density * step_size
tmp_occ_grid = -torch.ones_like(self.occ_grid)
grid_coords = self.grid_coords[indices]
x = (grid_coords + torch.rand_like(grid_coords.float())) / resolution
x = (
grid_coords + torch.rand_like(grid_coords.float())
) / self.resolution_tensor
bb_min, bb_max = torch.split(self.aabb, [self.num_dim, self.num_dim], dim=0)
x = x * (bb_max - bb_min) + bb_min
tmp_occ_grid[indices] = self.occ_eval_fn(x).squeeze(-1)
......@@ -152,8 +152,8 @@ class OccupancyField(nn.Module):
self.occ_grid[ema_mask] * ema_decay, tmp_occ_grid[ema_mask]
)
self.occ_grid_mean = self.occ_grid.mean()
self.occ_grid_binary = self.occ_grid > min(
self.occ_grid_mean.item(), occ_threshold
self.occ_grid_binary = self.occ_grid > torch.clamp(
self.occ_grid_mean, max=occ_threshold
)
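# torch.clamp(mean, max=occ_threshold) equals min(mean, occ_threshold) but
# keeps the comparison on the GPU, avoiding an .item() sync per update.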
@torch.no_grad()
......
......@@ -16,6 +16,7 @@ def volumetric_rendering(
render_bkgd: torch.Tensor = None,
render_n_samples: int = 1024,
render_est_n_samples: int = None,
render_step_size: int = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""A *fast* version of differentiable volumetric rendering."""
......@@ -23,8 +24,6 @@ def volumetric_rendering(
if render_bkgd is None:
render_bkgd = torch.ones(3, device=device)
# scene_resolution = torch.tensor(scene_resolution, dtype=torch.int, device=device)
rays_o = rays_o.contiguous()
rays_d = rays_d.contiguous()
scene_aabb = scene_aabb.contiguous()
......@@ -36,22 +35,22 @@ def volumetric_rendering(
render_total_samples = n_rays * render_n_samples
else:
render_total_samples = render_est_n_samples
render_step_size = (
(scene_aabb[3:] - scene_aabb[:3]).max() * math.sqrt(3) / render_n_samples
)
if render_step_size is None:
# Note: the CPU<->GPU sync here is not ideal; try to pre-define it outside this function.
render_step_size = (
(scene_aabb[3:] - scene_aabb[:3]).max() * math.sqrt(3) / render_n_samples
)
with torch.no_grad():
t_min, t_max = ray_aabb_intersect(rays_o, rays_d, scene_aabb)
# t_min = torch.clamp(t_min, max=1e10)
# t_max = torch.clamp(t_max, max=1e10)
(
# packed_info,
# frustum_origins,
# frustum_dirs,
# frustum_starts,
# frustum_ends,
# steps_counter,
packed_info,
frustum_origins,
frustum_dirs,
frustum_starts,
frustum_ends,
steps_counter,
) = ray_marching(
# rays
rays_o,
......@@ -68,43 +67,41 @@ def volumetric_rendering(
render_step_size,
)
# # squeeze valid samples
# total_samples = max(packed_info[:, -1].sum(), 1)
# total_samples = int(math.ceil(total_samples / 128.0)) * 128
# frustum_origins = frustum_origins[:total_samples]
# frustum_dirs = frustum_dirs[:total_samples]
# frustum_starts = frustum_starts[:total_samples]
# frustum_ends = frustum_ends[:total_samples]
# frustum_positions = (
# frustum_origins + frustum_dirs * (frustum_starts + frustum_ends) / 2.0
# )
# squeeze valid samples
total_samples = max(packed_info[:, -1].sum(), 1)
frustum_origins = frustum_origins[:total_samples]
frustum_dirs = frustum_dirs[:total_samples]
frustum_starts = frustum_starts[:total_samples]
frustum_ends = frustum_ends[:total_samples]
# query_results = query_fn(frustum_positions, frustum_dirs, **kwargs)
# rgbs, densities = query_results[0], query_results[1]
frustum_positions = (
frustum_origins + frustum_dirs * (frustum_starts + frustum_ends) / 2.0
)
# (
# accumulated_weight,
# accumulated_depth,
# accumulated_color,
# alive_ray_mask,
# compact_steps_counter,
# ) = VolumeRenderer.apply(
# packed_info,
# frustum_starts,
# frustum_ends,
# densities.contiguous(),
# rgbs.contiguous(),
# )
query_results = query_fn(frustum_positions, frustum_dirs, **kwargs)
rgbs, densities = query_results[0], query_results[1]
(
accumulated_weight,
accumulated_depth,
accumulated_color,
alive_ray_mask,
compact_steps_counter,
) = VolumeRenderer.apply(
packed_info,
frustum_starts,
frustum_ends,
densities.contiguous(),
rgbs.contiguous(),
)
# accumulated_depth = torch.clip(accumulated_depth, t_min[:, None], t_max[:, None])
# accumulated_color = accumulated_color + render_bkgd * (1.0 - accumulated_weight)
accumulated_depth = torch.clip(accumulated_depth, t_min[:, None], t_max[:, None])
accumulated_color = accumulated_color + render_bkgd * (1.0 - accumulated_weight)
# return (
# accumulated_color,
# accumulated_depth,
# accumulated_weight,
# alive_ray_mask,
# steps_counter,
# compact_steps_counter,
# )
return (
accumulated_color,
accumulated_depth,
accumulated_weight,
alive_ray_mask,
steps_counter,
compact_steps_counter,
)
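Putting the pieces together, a call into the updated API might look like the following sketch (assumptions: the setup from `examples/trainval.py` in this commit and the return signature above):

```python
# Hypothetical usage mirroring examples/trainval.py; render_step_size is the
# precomputed Python float from the training script.
rgb, depth, acc, alive_ray_mask, counter, compact_counter = volumetric_rendering(
    query_fn=radiance_field.forward,  # {x, dir} -> {rgb, density}
    rays_o=rays.origins,
    rays_d=rays.viewdirs,
    scene_aabb=occ_field.aabb,
    scene_occ_binary=occ_field.occ_grid_binary,
    scene_resolution=occ_field.resolution,
    render_bkgd=render_bkgd,
    render_n_samples=1024,
    render_step_size=render_step_size,
)
```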