Commit 96211bba authored by Ruilong Li

wtf

parent 65bebd64
@@ -65,6 +65,7 @@ class SubjectLoader(torch.utils.data.Dataset):
     WIDTH, HEIGHT = 800, 800
     NEAR, FAR = 2.0, 6.0
+    OPENGL_CAMERA = True

     def __init__(
         self,
@@ -186,15 +187,18 @@ class SubjectLoader(torch.utils.data.Dataset):
         camera_dirs = F.pad(
             torch.stack(
                 [
-                    (x - self.K[0, 2] + 0.5) / self.K[0, 0],
-                    (y - self.K[1, 2] + 0.5) / self.K[1, 1],
+                    (x - self.K[0, 2] + 0.5)
+                    / self.K[0, 0]
+                    * (-1.0 if self.OPENGL_CAMERA else 1.0),
+                    (y - self.K[1, 2] + 0.5)
+                    / self.K[1, 1]
+                    * (-1.0 if self.OPENGL_CAMERA else 1.0),
                 ],
                 dim=-1,
             ),
             (0, 1),
             value=1,
         )  # [num_rays, 3]
-        camera_dirs[..., [1, 2]] *= -1  # opengl format

         # [n_cams, height, width, 3]
         directions = (camera_dirs[:, None, :] * c2w[:, :3, :3]).sum(dim=-1)
...
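The hunk above moves the OpenGL-convention sign flip into the per-pixel direction computation (guarded by the new `OPENGL_CAMERA` flag) instead of negating `camera_dirs[..., [1, 2]]` afterwards. A minimal standalone sketch of that pixel-to-camera-direction mapping, mirroring the new in-stack computation; the helper name `pixel_dirs` and the toy intrinsics are illustrative, not part of the repository:

```python
import torch
import torch.nn.functional as F


def pixel_dirs(x, y, K, opengl_camera=True):
    """Map pixel centers (x, y) to camera-space ray directions, as in the diff above."""
    sign = -1.0 if opengl_camera else 1.0
    dirs = torch.stack(
        [
            (x - K[0, 2] + 0.5) / K[0, 0] * sign,  # horizontal component
            (y - K[1, 2] + 0.5) / K[1, 1] * sign,  # vertical component
        ],
        dim=-1,
    )
    # pad the last axis with 1 so each direction becomes [dx, dy, 1]
    return F.pad(dirs, (0, 1), value=1.0)  # [..., 3]


# usage: directions for a 4x4 pixel grid with a toy intrinsics matrix
K = torch.tensor([[800.0, 0.0, 2.0], [0.0, 800.0, 2.0], [0.0, 0.0, 1.0]])
ys, xs = torch.meshgrid(torch.arange(4.0), torch.arange(4.0), indexing="ij")
print(pixel_dirs(xs, ys, K).shape)  # torch.Size([4, 4, 3])
```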
@@ -5,7 +5,7 @@ import numpy as np
 import torch
 import torch.nn.functional as F
 import tqdm
-from datasets.nerf_synthetic import SubjectLoader, namedtuple_map
+from datasets.nerf_synthetic import Rays, SubjectLoader, namedtuple_map
 from radiance_fields.ngp import NGPradianceField

 from nerfacc import OccupancyField, volumetric_rendering
@@ -67,10 +67,10 @@ if __name__ == "__main__":
     # setup dataset
     train_dataset = SubjectLoader(
-        subject_id="lego",
+        subject_id="mic",
         root_fp="/home/ruilongli/data/nerf_synthetic/",
-        split="train",
-        num_rays=4096,
+        split="trainval",
+        num_rays=409600,
     )

     train_dataset.images = train_dataset.images.to(device)
@@ -85,7 +85,7 @@ if __name__ == "__main__":
     )

     test_dataset = SubjectLoader(
-        subject_id="lego",
+        subject_id="mic",
         root_fp="/home/ruilongli/data/nerf_synthetic/",
         split="test",
         num_rays=None,
@@ -144,12 +144,27 @@ if __name__ == "__main__":
         occ_eval_fn=occ_eval_fn, aabb=scene_aabb, resolution=128
     ).to(device)

+    render_bkgd = torch.ones(3, device=device)
+
     # training
     step = 0
     tic = time.time()
     data_time = 0
     tic_data = time.time()
-    for epoch in range(400):
+    weights_image_ids = torch.ones((len(train_dataset.images),), device=device)
+    weights_xs = torch.ones(
+        (train_dataset.WIDTH,),
+        device=device,
+    )
+    weights_ys = torch.ones(
+        (train_dataset.HEIGHT,),
+        device=device,
+    )
+
+    for epoch in range(40000000):
+        data = train_dataset[0]
+
         for i in range(len(train_dataset)):
             data = train_dataset[i]
             data_time += time.time() - tic_data
@@ -162,53 +177,66 @@ if __name__ == "__main__":
             pixels = data["pixels"].to(device)
             render_bkgd = data["color_bkgd"].to(device)

-            # update occupancy grid
-            occ_field.every_n_step(step)
-            rgb, depth, acc, alive_ray_mask, counter, compact_counter = render_image(
-                radiance_field, rays, render_bkgd
-            )
-            num_rays = len(pixels)
-            num_rays = int(num_rays * (2**16 / float(compact_counter)))
-            num_rays = int(math.ceil(num_rays / 128.0) * 128)
-            train_dataset.update_num_rays(num_rays)
-            # compute loss
-            loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
-            optimizer.zero_grad()
-            (loss * 128.0).backward()
-            optimizer.step()
-            scheduler.step()
+            # # update occupancy grid
+            # occ_field.every_n_step(step)
+            render_est_n_samples = 2**16 * 16 if radiance_field.training else None
+            volumetric_rendering(
+                query_fn=radiance_field.forward,  # {x, dir} -> {rgb, density}
+                rays_o=rays.origins,
+                rays_d=rays.viewdirs,
+                scene_aabb=occ_field.aabb,
+                scene_occ_binary=occ_field.occ_grid_binary,
+                scene_resolution=occ_field.resolution,
+                render_bkgd=render_bkgd,
+                render_n_samples=render_n_samples,
+                render_est_n_samples=render_est_n_samples,  # memory control: wrost case
+            )
+            # rgb, depth, acc, alive_ray_mask, counter, compact_counter = render_image(
+            #     radiance_field, rays, render_bkgd
+            # )
+            # num_rays = len(pixels)
+            # num_rays = int(num_rays * (2**16 / float(compact_counter)))
+            # num_rays = int(math.ceil(num_rays / 128.0) * 128)
+            # train_dataset.update_num_rays(num_rays)
+            # # compute loss
+            # loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
+            # optimizer.zero_grad()
+            # (loss * 128.0).backward()
+            # optimizer.step()
+            # scheduler.step()

             if step % 50 == 0:
                 elapsed_time = time.time() - tic
                 print(
                     f"elapsed_time={elapsed_time:.2f}s (data={data_time:.2f}s) | {step=} | "
-                    f"loss={loss:.5f} | "
-                    f"alive_ray_mask={alive_ray_mask.long().sum():d} | "
-                    f"counter={counter:d} | compact_counter={compact_counter:d} | num_rays={len(pixels):d} "
+                    # f"loss={loss:.5f} | "
+                    # f"alive_ray_mask={alive_ray_mask.long().sum():d} | "
+                    # f"counter={counter:d} | compact_counter={compact_counter:d} | num_rays={len(pixels):d} "
                 )

-            if step % 35_000 == 0 and step > 0:
-                # evaluation
-                radiance_field.eval()
-                psnrs = []
-                with torch.no_grad():
-                    for data in tqdm.tqdm(test_dataloader):
-                        # generate rays from data and the gt pixel color
-                        rays = namedtuple_map(lambda x: x.to(device), data["rays"])
-                        pixels = data["pixels"].to(device)
-                        render_bkgd = data["color_bkgd"].to(device)
-                        # rendering
-                        rgb, depth, acc, alive_ray_mask, _, _ = render_image(
-                            radiance_field, rays, render_bkgd
-                        )
-                        mse = F.mse_loss(rgb, pixels)
-                        psnr = -10.0 * torch.log(mse) / np.log(10.0)
-                        psnrs.append(psnr.item())
-                psnr_avg = sum(psnrs) / len(psnrs)
-                print(f"evaluation: {psnr_avg=}")
+            # if step % 35_000 == 0 and step > 0:
+            #     # evaluation
+            #     radiance_field.eval()
+            #     psnrs = []
+            #     with torch.no_grad():
+            #         for data in tqdm.tqdm(test_dataloader):
+            #             # generate rays from data and the gt pixel color
+            #             rays = namedtuple_map(lambda x: x.to(device), data["rays"])
+            #             pixels = data["pixels"].to(device)
+            #             render_bkgd = data["color_bkgd"].to(device)
+            #             # rendering
+            #             rgb, depth, acc, alive_ray_mask, _, _ = render_image(
+            #                 radiance_field, rays, render_bkgd
+            #             )
+            #             mse = F.mse_loss(rgb, pixels)
+            #             psnr = -10.0 * torch.log(mse) / np.log(10.0)
+            #             psnrs.append(psnr.item())
+            #     psnr_avg = sum(psnrs) / len(psnrs)
+            #     print(f"evaluation: {psnr_avg=}")

             tic_data = time.time()
             step += 1
...
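The block commented out above resized the ray batch each step from the previous step's compacted sample count, keeping the total number of samples near a fixed budget of 2**16. A minimal sketch of that resizing rule (the helper name `adaptive_num_rays` is illustrative):

```python
import math


def adaptive_num_rays(cur_num_rays: int, compact_counter: int, budget: int = 2**16) -> int:
    """Rescale the ray batch so the next step lands near `budget` compacted samples."""
    # guard against division by zero on the first step or a fully empty batch
    compact_counter = max(int(compact_counter), 1)
    num_rays = int(cur_num_rays * (budget / float(compact_counter)))
    # round up to a multiple of 128 to keep CUDA-friendly batch shapes
    return int(math.ceil(num_rays / 128.0) * 128)


# usage: 4096 rays produced 20_000 compacted samples last step -> grow the batch
print(adaptive_num_rays(4096, 20_000))  # 13440
```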
@@ -19,8 +19,8 @@ inline __host__ __device__ void _ray_aabb_intersect(
     if (tymin > tymax) __swap(tymin, tymax);
     if (tmin > tymax || tymin > tmax){
-        *near = std::numeric_limits<scalar_t>::max();
-        *far = std::numeric_limits<scalar_t>::max();
+        *near = 1e10;
+        *far = 1e10;
         return;
     }
@@ -32,8 +32,8 @@ inline __host__ __device__ void _ray_aabb_intersect(
     if (tzmin > tzmax) __swap(tzmin, tzmax);
     if (tmin > tzmax || tzmin > tmax){
-        *near = std::numeric_limits<scalar_t>::max();
-        *far = std::numeric_limits<scalar_t>::max();
+        *near = 1e10;
+        *far = 1e10;
         return;
     }
@@ -103,7 +103,7 @@ std::vector<torch::Tensor> ray_aabb_intersect(
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(
         rays_o.scalar_type(), "ray_aabb_intersect",
         ([&] {
-            kernel_ray_aabb_intersect<scalar_t><<<blocks, threads>>>(
+            kernel_ray_aabb_intersect<scalar_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                 N,
                 rays_o.data_ptr<scalar_t>(),
                 rays_d.data_ptr<scalar_t>(),
...
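The hunks above swap the `std::numeric_limits<scalar_t>::max()` miss sentinel for a plain `1e10`, which lines up with commenting out the `torch.clamp(..., max=1e10)` workaround in the `volumetric_rendering` hunk further down. For reference, a minimal NumPy sketch of the same slab-style ray/AABB test with that sentinel (the function name `ray_aabb_intersect_np` is illustrative, not the extension's binding):

```python
import numpy as np

MISS = 1e10  # sentinel returned when the ray misses the box


def ray_aabb_intersect_np(o, d, aabb):
    """Slab test: return (near, far) along the ray, or (MISS, MISS) on a miss."""
    o, d = np.asarray(o, float), np.asarray(d, float)
    lo, hi = np.asarray(aabb[:3], float), np.asarray(aabb[3:], float)
    with np.errstate(divide="ignore"):
        inv = 1.0 / d  # axis-parallel rays give +/- inf, which the min/max handle
    t0, t1 = (lo - o) * inv, (hi - o) * inv
    near = np.max(np.minimum(t0, t1))
    far = np.min(np.maximum(t0, t1))
    if near > far:
        return MISS, MISS
    return near, far


# usage: a ray along +z through a [-1, 1]^3 box
print(ray_aabb_intersect_np([0, 0, -5], [0, 0, 1], [-1, -1, -1, 1, 1, 1]))  # (4.0, 6.0)
```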
@@ -16,7 +16,7 @@ std::vector<torch::Tensor> ray_marching(
     const torch::Tensor t_max,
     // density grid
     const torch::Tensor aabb,
-    const torch::Tensor resolution,
+    const pybind11::list resolution,
     const torch::Tensor occ_binary,
     // sampling
     const int max_total_samples,
...
+#include <pybind11/pybind11.h>
 #include "include/helpers_cuda.h"

 inline __device__ int cascaded_grid_idx_at(
     const float x, const float y, const float z,
-    const int* resolution, const float* aabb
+    const int resx, const int resy, const int resz,
+    const float* aabb
 ) {
     // TODO(ruilongli): if the x, y, z is outside the aabb, it will be clipped into aabb!!! We should just return false
-    int ix = (int)(((x - aabb[0]) / (aabb[3] - aabb[0])) * resolution[0]);
-    int iy = (int)(((y - aabb[1]) / (aabb[4] - aabb[1])) * resolution[1]);
-    int iz = (int)(((z - aabb[2]) / (aabb[5] - aabb[2])) * resolution[2]);
-    ix = __clamp(ix, 0, resolution[0]-1);
-    iy = __clamp(iy, 0, resolution[1]-1);
-    iz = __clamp(iz, 0, resolution[2]-1);
-    int idx = ix * resolution[1] * resolution[2] + iy * resolution[2] + iz;
+    int ix = (int)(((x - aabb[0]) / (aabb[3] - aabb[0])) * resx);
+    int iy = (int)(((y - aabb[1]) / (aabb[4] - aabb[1])) * resy);
+    int iz = (int)(((z - aabb[2]) / (aabb[5] - aabb[2])) * resz);
+    ix = __clamp(ix, 0, resx-1);
+    iy = __clamp(iy, 0, resy-1);
+    iz = __clamp(iz, 0, resz-1);
+    int idx = ix * resx * resy + iy * resz + iz;
     return idx;
 }

 inline __device__ bool grid_occupied_at(
     const float x, const float y, const float z,
-    const int* resolution, const float* aabb, const bool* occ_binary
+    const int resx, const int resy, const int resz,
+    const float* aabb, const bool* occ_binary
 ) {
-    int idx = cascaded_grid_idx_at(x, y, z, resolution, aabb);
+    int idx = cascaded_grid_idx_at(x, y, z, resx, resy, resz, aabb);
     return occ_binary[idx];
 }
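For reference, a minimal Python sketch of the occupancy-grid lookup: normalize the point into the AABB, clamp to the grid, and compute the row-major cell index as in the original `resolution[]`-based helper (`ix * resy * resz + iy * resz + iz`). The names `grid_idx` and `clamp` are illustrative:

```python
def grid_idx(x, y, z, resx, resy, resz, aabb):
    """Row-major cell index of a point inside `aabb` on a (resx, resy, resz) grid."""
    def clamp(v, lo, hi):
        return max(lo, min(v, hi))

    ix = clamp(int((x - aabb[0]) / (aabb[3] - aabb[0]) * resx), 0, resx - 1)
    iy = clamp(int((y - aabb[1]) / (aabb[4] - aabb[1]) * resy), 0, resy - 1)
    iz = clamp(int((z - aabb[2]) / (aabb[5] - aabb[2]) * resz), 0, resz - 1)
    # row-major layout: x is the slowest axis, z the fastest
    return ix * resy * resz + iy * resz + iz


# usage: the center of a [-1, 1]^3 box on a 128^3 grid maps to cell (64, 64, 64)
aabb = [-1.0, -1.0, -1.0, 1.0, 1.0, 1.0]
print(grid_idx(0.0, 0.0, 0.0, 128, 128, 128, aabb))  # 1056832
```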
@@ -28,13 +31,13 @@ inline __device__ float distance_to_next_voxel(
     float x, float y, float z,
     float dir_x, float dir_y, float dir_z,
     float idir_x, float idir_y, float idir_z,
-    const int* resolution
+    const int resx, const int resy, const int resz
 ) { // dda like step
     // TODO: warning: expression has no effect?
-    x, y, z = resolution[0] * x, resolution[1] * y, resolution[2] * z;
-    float tx = ((floorf(x + 0.5f + 0.5f * __sign(dir_x)) - x) * idir_x) / resolution[0];
-    float ty = ((floorf(y + 0.5f + 0.5f * __sign(dir_y)) - y) * idir_y) / resolution[1];
-    float tz = ((floorf(z + 0.5f + 0.5f * __sign(dir_z)) - z) * idir_z) / resolution[2];
+    x, y, z = resx * x, resy * y, resz * z;
+    float tx = ((floorf(x + 0.5f + 0.5f * __sign(dir_x)) - x) * idir_x) / resx;
+    float ty = ((floorf(y + 0.5f + 0.5f * __sign(dir_y)) - y) * idir_y) / resy;
+    float tz = ((floorf(z + 0.5f + 0.5f * __sign(dir_z)) - z) * idir_z) / resz;
     float t = min(min(tx, ty), tz);
     return fmaxf(t, 0.0f);
 }
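A minimal Python sketch of what this DDA-style helper intends: scale the position into grid units (which is what the `x, y, z = resx * x, ...` line aims for, per the TODO about the comma expression), then take the smallest per-axis distance to the next voxel boundary. The names `sign` and `distance_to_next_voxel` mirror the kernel helpers; the zero-direction guard is an addition for the sketch:

```python
import math


def sign(v):
    """Mirror of the kernel's __sign helper: -1, 0, or +1."""
    return (v > 0) - (v < 0)


def distance_to_next_voxel(x, y, z, dx, dy, dz, resx, resy, resz):
    """Distance along (dx, dy, dz) from a point in the unit cube to the next voxel boundary."""
    gx, gy, gz = x * resx, y * resy, z * resz  # position in grid units
    t = math.inf
    for g, d, res in ((gx, dx, resx), (gy, dy, resy), (gz, dz, resz)):
        if d != 0.0:
            # boundary crossing along this axis, converted back to world units
            t = min(t, (math.floor(g + 0.5 + 0.5 * sign(d)) - g) / d / res)
    return max(t, 0.0)


# usage: from the middle of a 128^3 grid, marching along +x
print(distance_to_next_voxel(0.5, 0.5, 0.5, 1.0, 0.0, 0.0, 128, 128, 128))  # 0.0078125
```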
@@ -44,10 +47,11 @@ inline __device__ float advance_to_next_voxel(
     float x, float y, float z,
     float dir_x, float dir_y, float dir_z,
     float idir_x, float idir_y, float idir_z,
-    const int* resolution, float dt_min) {
+    const int resx, const int resy, const int resz,
+    float dt_min) {
     // Regular stepping (may be slower but matches non-empty space)
     float t_target = t + distance_to_next_voxel(
-        x, y, z, dir_x, dir_y, dir_z, idir_x, idir_y, idir_z, resolution
+        x, y, z, dir_x, dir_y, dir_z, idir_x, idir_y, idir_z, resx, resy, resz
     );
     do {
         t += dt_min;
@@ -65,7 +69,9 @@ __global__ void kernel_raymarching(
     const float* t_max,  // shape (n_rays,)
     // density grid
     const float* aabb,  // [min_x, min_y, min_z, max_x, max_y, max_y]
-    const int* resolution,  // [reso_x, reso_y, reso_z]
+    const int resx,
+    const int resy,
+    const int resz,
     const bool* occ_binary,  // shape (reso_x, reso_y, reso_z)
     // sampling
     const int max_total_samples,
@@ -83,102 +89,102 @@ __global__ void kernel_raymarching(
 ) {
     CUDA_GET_THREAD_ID(i, n_rays);

-    // locate
-    rays_o += i * 3;
-    rays_d += i * 3;
-    t_min += i;
-    t_max += i;
-    const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
-    const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
-    const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
-    const float near = t_min[0], far = t_max[0];
-    uint32_t ray_idx, base, marching_samples;
-    uint32_t j;
-    float t0, t1, t_mid;
-    // first pass to compute an accurate number of steps
-    j = 0;
-    t0 = near;  // TODO(ruilongli): perturb `near` as in ngp_pl?
-    t1 = t0 + dt;
-    t_mid = (t0 + t1) * 0.5f;
-    while (t_mid < far && j < max_per_ray_samples) {
-        // current center
-        const float x = ox + t_mid * dx;
-        const float y = oy + t_mid * dy;
-        const float z = oz + t_mid * dz;
-        if (grid_occupied_at(x, y, z, resolution, aabb, occ_binary)) {
-            ++j;
-            // march to next sample
-            t0 = t1;
-            t1 = t0 + dt;
-            t_mid = (t0 + t1) * 0.5f;
-        }
-        else {
-            // march to next sample
-            t_mid = advance_to_next_voxel(
-                t_mid, x, y, z, dx, dy, dz, rdx, rdy, rdz, resolution, dt
-            );
-            t0 = t_mid - dt * 0.5f;
-            t1 = t_mid + dt * 0.5f;
-        }
-    }
-    if (j == 0) return;
-    marching_samples = j;
-    base = atomicAdd(steps_counter, marching_samples);
-    if (base + marching_samples > max_total_samples) return;
-    ray_idx = atomicAdd(rays_counter, 1);
-    // locate
-    frustum_origins += base * 3;
-    frustum_dirs += base * 3;
-    frustum_starts += base;
-    frustum_ends += base;
-    // Second round
-    j = 0;
-    t0 = near;
-    t1 = t0 + dt;
-    t_mid = (t0 + t1) / 2.;
-    while (t_mid < far && j < marching_samples) {
-        // current center
-        const float x = ox + t_mid * dx;
-        const float y = oy + t_mid * dy;
-        const float z = oz + t_mid * dz;
-        if (grid_occupied_at(x, y, z, resolution, aabb, occ_binary)) {
-            frustum_origins[j * 3 + 0] = ox;
-            frustum_origins[j * 3 + 1] = oy;
-            frustum_origins[j * 3 + 2] = oz;
-            frustum_dirs[j * 3 + 0] = dx;
-            frustum_dirs[j * 3 + 1] = dy;
-            frustum_dirs[j * 3 + 2] = dz;
-            frustum_starts[j] = t0;
-            frustum_ends[j] = t1;
-            ++j;
-            // march to next sample
-            t0 = t1;
-            t1 = t0 + dt;
-            t_mid = (t0 + t1) * 0.5f;
-        }
-        else {
-            // march to next sample
-            t_mid = advance_to_next_voxel(
-                t_mid, x, y, z, dx, dy, dz, rdx, rdy, rdz, resolution, dt
-            );
-            t0 = t_mid - dt * 0.5f;
-            t1 = t_mid + dt * 0.5f;
-        }
-    }
-    packed_info[ray_idx * 3 + 0] = i;  // ray idx in {rays_o, rays_d}
-    packed_info[ray_idx * 3 + 1] = base;  // point idx start.
-    packed_info[ray_idx * 3 + 2] = j;  // point idx shift (actual marching samples).
+    // // locate
+    // rays_o += i * 3;
+    // rays_d += i * 3;
+    // t_min += i;
+    // t_max += i;
+    // const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
+    // const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
+    // const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
+    // const float near = t_min[0], far = t_max[0];
+    // uint32_t ray_idx, base, marching_samples;
+    // uint32_t j;
+    // float t0, t1, t_mid;
+    // // first pass to compute an accurate number of steps
+    // j = 0;
+    // t0 = near;  // TODO(ruilongli): perturb `near` as in ngp_pl?
+    // t1 = t0 + dt;
+    // t_mid = (t0 + t1) * 0.5f;
+    // while (t_mid < far && j < max_per_ray_samples) {
+    //     // current center
+    //     const float x = ox + t_mid * dx;
+    //     const float y = oy + t_mid * dy;
+    //     const float z = oz + t_mid * dz;
+    //     if (grid_occupied_at(x, y, z, resx, resy, resz, aabb, occ_binary)) {
+    //         ++j;
+    //         // march to next sample
+    //         t0 = t1;
+    //         t1 = t0 + dt;
+    //         t_mid = (t0 + t1) * 0.5f;
+    //     }
+    //     else {
+    //         // march to next sample
+    //         t_mid = advance_to_next_voxel(
+    //             t_mid, x, y, z, dx, dy, dz, rdx, rdy, rdz, resx, resy, resz, dt
+    //         );
+    //         t0 = t_mid - dt * 0.5f;
+    //         t1 = t_mid + dt * 0.5f;
+    //     }
+    // }
+    // if (j == 0) return;
+    // marching_samples = j;
+    // base = atomicAdd(steps_counter, marching_samples);
+    // if (base + marching_samples > max_total_samples) return;
+    // ray_idx = atomicAdd(rays_counter, 1);
+    // // locate
+    // frustum_origins += base * 3;
+    // frustum_dirs += base * 3;
+    // frustum_starts += base;
+    // frustum_ends += base;
+    // // Second round
+    // j = 0;
+    // t0 = near;
+    // t1 = t0 + dt;
+    // t_mid = (t0 + t1) / 2.;
+    // while (t_mid < far && j < marching_samples) {
+    //     // current center
+    //     const float x = ox + t_mid * dx;
+    //     const float y = oy + t_mid * dy;
+    //     const float z = oz + t_mid * dz;
+    //     if (grid_occupied_at(x, y, z, resx, resy, resz, aabb, occ_binary)) {
+    //         frustum_origins[j * 3 + 0] = ox;
+    //         frustum_origins[j * 3 + 1] = oy;
+    //         frustum_origins[j * 3 + 2] = oz;
+    //         frustum_dirs[j * 3 + 0] = dx;
+    //         frustum_dirs[j * 3 + 1] = dy;
+    //         frustum_dirs[j * 3 + 2] = dz;
+    //         frustum_starts[j] = t0;
+    //         frustum_ends[j] = t1;
+    //         ++j;
+    //         // march to next sample
+    //         t0 = t1;
+    //         t1 = t0 + dt;
+    //         t_mid = (t0 + t1) * 0.5f;
+    //     }
+    //     else {
+    //         // march to next sample
+    //         t_mid = advance_to_next_voxel(
+    //             t_mid, x, y, z, dx, dy, dz, rdx, rdy, rdz, resx, resy, resz, dt
+    //         );
+    //         t0 = t_mid - dt * 0.5f;
+    //         t1 = t_mid + dt * 0.5f;
+    //     }
+    // }
+    // packed_info[ray_idx * 3 + 0] = i;  // ray idx in {rays_o, rays_d}
+    // packed_info[ray_idx * 3 + 1] = base;  // point idx start.
+    // packed_info[ray_idx * 3 + 2] = j;  // point idx shift (actual marching samples).

     return;
 }
@@ -220,67 +226,69 @@ std::vector<torch::Tensor> ray_marching(
     const torch::Tensor t_max,
     // density grid
     const torch::Tensor aabb,
-    const torch::Tensor resolution,
+    const pybind11::list resolution,
     const torch::Tensor occ_binary,
     // sampling
     const int max_total_samples,
     const int max_per_ray_samples,
     const float dt
 ) {
-    DEVICE_GUARD(rays_o);
-    CHECK_INPUT(rays_o);
-    CHECK_INPUT(rays_d);
-    CHECK_INPUT(t_min);
-    CHECK_INPUT(t_max);
-    CHECK_INPUT(aabb);
-    CHECK_INPUT(resolution);
-    CHECK_INPUT(occ_binary);
-    const int n_rays = rays_o.size(0);
-    const int threads = 256;
-    const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
-    // helper counter
-    torch::Tensor steps_counter = torch::zeros(
-        {1}, rays_o.options().dtype(torch::kInt32));
-    torch::Tensor rays_counter = torch::zeros(
-        {1}, rays_o.options().dtype(torch::kInt32));
-    // output frustum samples
-    torch::Tensor packed_info = torch::zeros(
-        {n_rays, 3}, rays_o.options().dtype(torch::kInt32));  // ray_id, sample_id, num_samples
-    torch::Tensor frustum_origins = torch::zeros({max_total_samples, 3}, rays_o.options());
-    torch::Tensor frustum_dirs = torch::zeros({max_total_samples, 3}, rays_o.options());
-    torch::Tensor frustum_starts = torch::zeros({max_total_samples, 1}, rays_o.options());
-    torch::Tensor frustum_ends = torch::zeros({max_total_samples, 1}, rays_o.options());
-    kernel_raymarching<<<blocks, threads>>>(
-        // rays
-        n_rays,
-        rays_o.data_ptr<float>(),
-        rays_d.data_ptr<float>(),
-        t_min.data_ptr<float>(),
-        t_max.data_ptr<float>(),
-        // density grid
-        aabb.data_ptr<float>(),
-        resolution.data_ptr<int>(),
-        occ_binary.data_ptr<bool>(),
-        // sampling
-        max_total_samples,
-        max_per_ray_samples,
-        dt,
-        // writable helpers
-        steps_counter.data_ptr<int>(),  // total samples.
-        rays_counter.data_ptr<int>(),  // total rays.
-        packed_info.data_ptr<int>(),
-        frustum_origins.data_ptr<float>(),
-        frustum_dirs.data_ptr<float>(),
-        frustum_starts.data_ptr<float>(),
-        frustum_ends.data_ptr<float>()
-    );
-    return {packed_info, frustum_origins, frustum_dirs, frustum_starts, frustum_ends, steps_counter};
+    // DEVICE_GUARD(rays_o);
+    // CHECK_INPUT(rays_o);
+    // CHECK_INPUT(rays_d);
+    // CHECK_INPUT(t_min);
+    // CHECK_INPUT(t_max);
+    // CHECK_INPUT(aabb);
+    // CHECK_INPUT(occ_binary);
+    // const int n_rays = rays_o.size(0);
+    // // const int threads = 256;
+    // // const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
+    // // helper counter
+    // torch::Tensor steps_counter = torch::zeros(
+    //     {1}, rays_o.options().dtype(torch::kInt32));
+    // torch::Tensor rays_counter = torch::zeros(
+    //     {1}, rays_o.options().dtype(torch::kInt32));
+    // // output frustum samples
+    // torch::Tensor packed_info = torch::zeros(
+    //     {n_rays, 3}, rays_o.options().dtype(torch::kInt32));  // ray_id, sample_id, num_samples
+    // torch::Tensor frustum_origins = torch::zeros({max_total_samples, 3}, rays_o.options());
+    // torch::Tensor frustum_dirs = torch::zeros({max_total_samples, 3}, rays_o.options());
+    // torch::Tensor frustum_starts = torch::zeros({max_total_samples, 1}, rays_o.options());
+    // torch::Tensor frustum_ends = torch::zeros({max_total_samples, 1}, rays_o.options());
+    // kernel_raymarching<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
+    //     // rays
+    //     n_rays,
+    //     rays_o.data_ptr<float>(),
+    //     rays_d.data_ptr<float>(),
+    //     t_min.data_ptr<float>(),
+    //     t_max.data_ptr<float>(),
+    //     // density grid
+    //     aabb.data_ptr<float>(),
+    //     resolution[0].cast<int>(),
+    //     resolution[1].cast<int>(),
+    //     resolution[2].cast<int>(),
+    //     occ_binary.data_ptr<bool>(),
+    //     // sampling
+    //     max_total_samples,
+    //     max_per_ray_samples,
+    //     dt,
+    //     // writable helpers
+    //     steps_counter.data_ptr<int>(),  // total samples.
+    //     rays_counter.data_ptr<int>(),  // total rays.
+    //     packed_info.data_ptr<int>(),
+    //     frustum_origins.data_ptr<float>(),
+    //     frustum_dirs.data_ptr<float>(),
+    //     frustum_starts.data_ptr<float>(),
+    //     frustum_ends.data_ptr<float>()
+    // );
+    // return {packed_info, frustum_origins, frustum_dirs, frustum_starts, frustum_ends, steps_counter};
+    return {};
 }
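The (currently disabled) kernel above marches each ray twice: a first pass counts samples that land in occupied voxels, an atomic add reserves a contiguous slab in the output buffers, and a second pass writes the samples plus a `packed_info = (ray_idx, base, count)` record. A minimal single-threaded Python sketch of that packing scheme, where the atomics reduce to a running offset (the names `pack_rays` and `per_ray_samples` are illustrative):

```python
def pack_rays(per_ray_samples):
    """Flatten variable-length per-ray samples into shared buffers plus (ray_idx, start, count) records."""
    packed_info, t_starts, t_ends = [], [], []
    base = 0
    for ray_idx, segments in enumerate(per_ray_samples):
        if not segments:  # mirrors the kernel's `if (j == 0) return;`
            continue
        packed_info.append((ray_idx, base, len(segments)))
        for t0, t1 in segments:
            t_starts.append(t0)
            t_ends.append(t1)
        base += len(segments)
    return packed_info, t_starts, t_ends


# usage: ray 0 hits two occupied voxels, ray 1 misses everything, ray 2 hits one
info, starts, ends = pack_rays([[(2.0, 2.1), (2.1, 2.2)], [], [(3.0, 3.1)]])
print(info)    # [(0, 0, 2), (2, 2, 1)]
print(starts)  # [2.0, 2.1, 3.0]
```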
@@ -22,7 +22,8 @@ def volumetric_rendering(
     device = rays_o.device
     if render_bkgd is None:
         render_bkgd = torch.ones(3, device=device)
-    scene_resolution = torch.tensor(scene_resolution, dtype=torch.int, device=device)
+    # scene_resolution = torch.tensor(scene_resolution, dtype=torch.int, device=device)

     rays_o = rays_o.contiguous()
     rays_d = rays_d.contiguous()
@@ -40,18 +41,17 @@ def volumetric_rendering(
     )

     with torch.no_grad():
-        # TODO: avoid clamp here. kinda stupid
         t_min, t_max = ray_aabb_intersect(rays_o, rays_d, scene_aabb)
-        t_min = torch.clamp(t_min, max=1e10)
-        t_max = torch.clamp(t_max, max=1e10)
+        # t_min = torch.clamp(t_min, max=1e10)
+        # t_max = torch.clamp(t_max, max=1e10)
         (
-            packed_info,
-            frustum_origins,
-            frustum_dirs,
-            frustum_starts,
-            frustum_ends,
-            steps_counter,
+            # packed_info,
+            # frustum_origins,
+            # frustum_dirs,
+            # frustum_starts,
+            # frustum_ends,
+            # steps_counter,
         ) = ray_marching(
             # rays
             rays_o,
@@ -68,43 +68,43 @@
             render_step_size,
         )

-        # squeeze valid samples
-        total_samples = max(packed_info[:, -1].sum(), 1)
-        total_samples = int(math.ceil(total_samples / 128.0)) * 128
-        frustum_origins = frustum_origins[:total_samples]
-        frustum_dirs = frustum_dirs[:total_samples]
-        frustum_starts = frustum_starts[:total_samples]
-        frustum_ends = frustum_ends[:total_samples]
-    frustum_positions = (
-        frustum_origins + frustum_dirs * (frustum_starts + frustum_ends) / 2.0
-    )
-    query_results = query_fn(frustum_positions, frustum_dirs, **kwargs)
-    rgbs, densities = query_results[0], query_results[1]
-    (
-        accumulated_weight,
-        accumulated_depth,
-        accumulated_color,
-        alive_ray_mask,
-        compact_steps_counter,
-    ) = VolumeRenderer.apply(
-        packed_info,
-        frustum_starts,
-        frustum_ends,
-        densities.contiguous(),
-        rgbs.contiguous(),
-    )
-    accumulated_depth = torch.clip(accumulated_depth, t_min[:, None], t_max[:, None])
-    accumulated_color = accumulated_color + render_bkgd * (1.0 - accumulated_weight)
-    return (
-        accumulated_color,
-        accumulated_depth,
-        accumulated_weight,
-        alive_ray_mask,
-        steps_counter,
-        compact_steps_counter,
-    )
+        # # squeeze valid samples
+        # total_samples = max(packed_info[:, -1].sum(), 1)
+        # total_samples = int(math.ceil(total_samples / 128.0)) * 128
+        # frustum_origins = frustum_origins[:total_samples]
+        # frustum_dirs = frustum_dirs[:total_samples]
+        # frustum_starts = frustum_starts[:total_samples]
+        # frustum_ends = frustum_ends[:total_samples]
+    # frustum_positions = (
+    #     frustum_origins + frustum_dirs * (frustum_starts + frustum_ends) / 2.0
+    # )
+    # query_results = query_fn(frustum_positions, frustum_dirs, **kwargs)
+    # rgbs, densities = query_results[0], query_results[1]
+    # (
+    #     accumulated_weight,
+    #     accumulated_depth,
+    #     accumulated_color,
+    #     alive_ray_mask,
+    #     compact_steps_counter,
+    # ) = VolumeRenderer.apply(
+    #     packed_info,
+    #     frustum_starts,
+    #     frustum_ends,
+    #     densities.contiguous(),
+    #     rgbs.contiguous(),
+    # )
+    # accumulated_depth = torch.clip(accumulated_depth, t_min[:, None], t_max[:, None])
+    # accumulated_color = accumulated_color + render_bkgd * (1.0 - accumulated_weight)
+    # return (
+    #     accumulated_color,
+    #     accumulated_depth,
+    #     accumulated_weight,
+    #     alive_ray_mask,
+    #     steps_counter,
+    #     compact_steps_counter,
+    # )
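The commented-out tail above composites the accumulated color over the background as `accumulated_color + render_bkgd * (1.0 - accumulated_weight)`, where the accumulated weight comes from the usual alpha compositing of per-sample densities. A minimal single-ray reference sketch of that quadrature; this is the standard formulation, not the repository's `VolumeRenderer` kernel, and the names are illustrative:

```python
import torch


def composite_one_ray(densities, rgbs, t_starts, t_ends, render_bkgd):
    """Standard volume-rendering quadrature for one ray's samples."""
    deltas = t_ends - t_starts                      # [n_samples] segment lengths
    alphas = 1.0 - torch.exp(-densities * deltas)   # per-sample opacity
    trans = torch.cumprod(
        torch.cat([torch.ones(1), 1.0 - alphas[:-1]]), dim=0
    )                                               # transmittance before each sample
    weights = alphas * trans                        # [n_samples]
    color = (weights[:, None] * rgbs).sum(dim=0)    # [3]
    opacity = weights.sum()                         # the accumulated_weight above
    # blend with the background exactly as in the diff above
    return color + render_bkgd * (1.0 - opacity)


# usage: two samples composited over a white background
print(composite_one_ray(
    torch.tensor([10.0, 10.0]),
    torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]),
    torch.tensor([2.0, 2.1]),
    torch.tensor([2.1, 2.2]),
    torch.ones(3),
))
```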