"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "fee93c81eb7c5e9fe1618f858f1e369567170edc"
Commit 96211bba authored by Ruilong Li

wtf

parent 65bebd64
@@ -65,6 +65,7 @@ class SubjectLoader(torch.utils.data.Dataset):
WIDTH, HEIGHT = 800, 800
NEAR, FAR = 2.0, 6.0
OPENGL_CAMERA = True
def __init__(
self,
@@ -186,15 +187,18 @@ class SubjectLoader(torch.utils.data.Dataset):
camera_dirs = F.pad(
torch.stack(
[
(x - self.K[0, 2] + 0.5) / self.K[0, 0],
(y - self.K[1, 2] + 0.5) / self.K[1, 1],
(x - self.K[0, 2] + 0.5)
/ self.K[0, 0]
* (-1.0 if self.OPENGL_CAMERA else 1.0),
(y - self.K[1, 2] + 0.5)
/ self.K[1, 1]
* (-1.0 if self.OPENGL_CAMERA else 1.0),
],
dim=-1,
),
(0, 1),
value=1,
) # [num_rays, 3]
camera_dirs[..., [1, 2]] *= -1 # opengl format
# [n_cams, height, width, 3]
directions = (camera_dirs[:, None, :] * c2w[:, :3, :3]).sum(dim=-1)
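For context, the last line of this hunk is a batched rotation of per-pixel camera-frame directions into world space. Below is a minimal, hedged sketch of that pattern with made-up intrinsics `K` and identity poses `c2w` (neither taken from the repo); the sign-convention handling that this commit moves around is left out.

```python
# Minimal sketch (toy intrinsics and identity poses, not from the repo):
# build pinhole ray directions in the camera frame, then rotate them to
# world space exactly as the last line of the hunk does.
import torch
import torch.nn.functional as F

W, H = 4, 3
K = torch.tensor([[100.0, 0.0, W / 2.0],
                  [0.0, 100.0, H / 2.0],
                  [0.0, 0.0, 1.0]])
num_rays = W * H
c2w = torch.eye(4).expand(num_rays, 4, 4)  # one (identity) camera-to-world pose per ray

y, x = torch.meshgrid(torch.arange(H, dtype=torch.float32),
                      torch.arange(W, dtype=torch.float32), indexing="ij")
x, y = x.reshape(-1), y.reshape(-1)

camera_dirs = F.pad(
    torch.stack([(x - K[0, 2] + 0.5) / K[0, 0],
                 (y - K[1, 2] + 0.5) / K[1, 1]], dim=-1),
    (0, 1), value=1.0,
)  # [num_rays, 3]; the OpenGL/OpenCV sign handling is what the hunk above changes

# Broadcast-multiply + sum over the last dim is a per-ray matrix-vector product R @ d.
directions = (camera_dirs[:, None, :] * c2w[:, :3, :3]).sum(dim=-1)
assert torch.allclose(directions,
                      torch.einsum("nij,nj->ni", c2w[:, :3, :3], camera_dirs))
```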
@@ -5,7 +5,7 @@ import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from datasets.nerf_synthetic import SubjectLoader, namedtuple_map
from datasets.nerf_synthetic import Rays, SubjectLoader, namedtuple_map
from radiance_fields.ngp import NGPradianceField
from nerfacc import OccupancyField, volumetric_rendering
@@ -67,10 +67,10 @@ if __name__ == "__main__":
# setup dataset
train_dataset = SubjectLoader(
subject_id="lego",
subject_id="mic",
root_fp="/home/ruilongli/data/nerf_synthetic/",
split="train",
num_rays=4096,
split="trainval",
num_rays=409600,
)
train_dataset.images = train_dataset.images.to(device)
@@ -85,7 +85,7 @@ if __name__ == "__main__":
)
test_dataset = SubjectLoader(
subject_id="lego",
subject_id="mic",
root_fp="/home/ruilongli/data/nerf_synthetic/",
split="test",
num_rays=None,
@@ -144,12 +144,27 @@ if __name__ == "__main__":
occ_eval_fn=occ_eval_fn, aabb=scene_aabb, resolution=128
).to(device)
render_bkgd = torch.ones(3, device=device)
# training
step = 0
tic = time.time()
data_time = 0
tic_data = time.time()
for epoch in range(400):
weights_image_ids = torch.ones((len(train_dataset.images),), device=device)
weights_xs = torch.ones(
(train_dataset.WIDTH,),
device=device,
)
weights_ys = torch.ones(
(train_dataset.HEIGHT,),
device=device,
)
for epoch in range(40000000):
data = train_dataset[0]
for i in range(len(train_dataset)):
data = train_dataset[i]
data_time += time.time() - tic_data
@@ -162,53 +177,66 @@ if __name__ == "__main__":
pixels = data["pixels"].to(device)
render_bkgd = data["color_bkgd"].to(device)
# update occupancy grid
occ_field.every_n_step(step)
rgb, depth, acc, alive_ray_mask, counter, compact_counter = render_image(
radiance_field, rays, render_bkgd
# # update occupancy grid
# occ_field.every_n_step(step)
render_est_n_samples = 2**16 * 16 if radiance_field.training else None
volumetric_rendering(
query_fn=radiance_field.forward, # {x, dir} -> {rgb, density}
rays_o=rays.origins,
rays_d=rays.viewdirs,
scene_aabb=occ_field.aabb,
scene_occ_binary=occ_field.occ_grid_binary,
scene_resolution=occ_field.resolution,
render_bkgd=render_bkgd,
render_n_samples=render_n_samples,
render_est_n_samples=render_est_n_samples, # memory control: worst case
)
num_rays = len(pixels)
num_rays = int(num_rays * (2**16 / float(compact_counter)))
num_rays = int(math.ceil(num_rays / 128.0) * 128)
train_dataset.update_num_rays(num_rays)
# compute loss
loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
# rgb, depth, acc, alive_ray_mask, counter, compact_counter = render_image(
# radiance_field, rays, render_bkgd
# )
# num_rays = len(pixels)
# num_rays = int(num_rays * (2**16 / float(compact_counter)))
# num_rays = int(math.ceil(num_rays / 128.0) * 128)
# train_dataset.update_num_rays(num_rays)
# # compute loss
# loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
optimizer.zero_grad()
(loss * 128.0).backward()
optimizer.step()
scheduler.step()
# optimizer.zero_grad()
# (loss * 128.0).backward()
# optimizer.step()
# scheduler.step()
if step % 50 == 0:
elapsed_time = time.time() - tic
print(
f"elapsed_time={elapsed_time:.2f}s (data={data_time:.2f}s) | {step=} | "
f"loss={loss:.5f} | "
f"alive_ray_mask={alive_ray_mask.long().sum():d} | "
f"counter={counter:d} | compact_counter={compact_counter:d} | num_rays={len(pixels):d} "
# f"loss={loss:.5f} | "
# f"alive_ray_mask={alive_ray_mask.long().sum():d} | "
# f"counter={counter:d} | compact_counter={compact_counter:d} | num_rays={len(pixels):d} "
)
if step % 35_000 == 0 and step > 0:
# evaluation
radiance_field.eval()
psnrs = []
with torch.no_grad():
for data in tqdm.tqdm(test_dataloader):
# generate rays from data and the gt pixel color
rays = namedtuple_map(lambda x: x.to(device), data["rays"])
pixels = data["pixels"].to(device)
render_bkgd = data["color_bkgd"].to(device)
# rendering
rgb, depth, acc, alive_ray_mask, _, _ = render_image(
radiance_field, rays, render_bkgd
)
mse = F.mse_loss(rgb, pixels)
psnr = -10.0 * torch.log(mse) / np.log(10.0)
psnrs.append(psnr.item())
psnr_avg = sum(psnrs) / len(psnrs)
print(f"evaluation: {psnr_avg=}")
# if step % 35_000 == 0 and step > 0:
# # evaluation
# radiance_field.eval()
# psnrs = []
# with torch.no_grad():
# for data in tqdm.tqdm(test_dataloader):
# # generate rays from data and the gt pixel color
# rays = namedtuple_map(lambda x: x.to(device), data["rays"])
# pixels = data["pixels"].to(device)
# render_bkgd = data["color_bkgd"].to(device)
# # rendering
# rgb, depth, acc, alive_ray_mask, _, _ = render_image(
# radiance_field, rays, render_bkgd
# )
# mse = F.mse_loss(rgb, pixels)
# psnr = -10.0 * torch.log(mse) / np.log(10.0)
# psnrs.append(psnr.item())
# psnr_avg = sum(psnrs) / len(psnrs)
# print(f"evaluation: {psnr_avg=}")
tic_data = time.time()
step += 1
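The commented-out resizing logic earlier in this loop scales the ray batch so each step produces roughly 2**16 samples after compaction, rounded up to a multiple of 128. A standalone, hedged sketch of that rule follows; `update_num_rays` here is an illustrative helper (not the dataset method of the same name) and `compact_counter` is a made-up input.

```python
# Standalone sketch of the dynamic ray-batch rule in the commented-out block
# above; the 2**16 target mirrors the constant used in the training script.
import math

def update_num_rays(num_rays: int, compact_counter: int,
                    target_samples: int = 2 ** 16, multiple: int = 128) -> int:
    """Rescale the ray batch so the next step lands near `target_samples`
    compacted samples, rounded up to a kernel-friendly multiple."""
    num_rays = int(num_rays * (target_samples / float(max(compact_counter, 1))))
    return int(math.ceil(num_rays / multiple) * multiple)

# e.g. 4096 rays that produced 131072 compacted samples -> halve the batch
print(update_num_rays(4096, compact_counter=131072))  # 2048
```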
@@ -19,8 +19,8 @@ inline __host__ __device__ void _ray_aabb_intersect(
if (tymin > tymax) __swap(tymin, tymax);
if (tmin > tymax || tymin > tmax){
*near = std::numeric_limits<scalar_t>::max();
*far = std::numeric_limits<scalar_t>::max();
*near = 1e10;
*far = 1e10;
return;
}
@@ -32,8 +32,8 @@ inline __host__ __device__ void _ray_aabb_intersect(
if (tzmin > tzmax) __swap(tzmin, tzmax);
if (tmin > tzmax || tzmin > tmax){
*near = std::numeric_limits<scalar_t>::max();
*far = std::numeric_limits<scalar_t>::max();
*near = 1e10;
*far = 1e10;
return;
}
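These two hunks replace the FLT_MAX miss sentinel with 1e10, which is why the clamps in the Python wrapper further down can be commented out. A hedged PyTorch transcription of the same slab test is sketched below with toy rays; it is not the repo's CUDA path.

```python
# Hedged PyTorch transcription of the slab test (toy rays, not the CUDA kernel):
# rays that miss the box get near = far = 1e10, so no downstream clamp is needed.
import torch

def ray_aabb_intersect(rays_o, rays_d, aabb, miss_value=1e10):
    # aabb = [min_x, min_y, min_z, max_x, max_y, max_z]
    inv_d = 1.0 / rays_d                      # zero components become inf, as in CUDA
    t0 = (aabb[:3] - rays_o) * inv_d
    t1 = (aabb[3:] - rays_o) * inv_d
    t_near = torch.minimum(t0, t1).amax(dim=-1)
    t_far = torch.maximum(t0, t1).amin(dim=-1)
    miss = t_near > t_far
    t_near = torch.where(miss, torch.full_like(t_near, miss_value), t_near)
    t_far = torch.where(miss, torch.full_like(t_far, miss_value), t_far)
    return t_near, t_far

rays_o = torch.tensor([[0.0, 0.0, -3.0], [5.0, 5.0, 5.0]])   # second ray misses
rays_d = torch.tensor([[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]])
aabb = torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0])
print(ray_aabb_intersect(rays_o, rays_d, aabb))               # ([2., 1e10], [4., 1e10])
```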
@@ -103,7 +103,7 @@ std::vector<torch::Tensor> ray_aabb_intersect(
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
rays_o.scalar_type(), "ray_aabb_intersect",
([&] {
kernel_ray_aabb_intersect<scalar_t><<<blocks, threads>>>(
kernel_ray_aabb_intersect<scalar_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
N,
rays_o.data_ptr<scalar_t>(),
rays_d.data_ptr<scalar_t>(),
@@ -16,7 +16,7 @@ std::vector<torch::Tensor> ray_marching(
const torch::Tensor t_max,
// density grid
const torch::Tensor aabb,
const torch::Tensor resolution,
const pybind11::list resolution,
const torch::Tensor occ_binary,
// sampling
const int max_total_samples,
#include <pybind11/pybind11.h>
#include "include/helpers_cuda.h"
inline __device__ int cascaded_grid_idx_at(
const float x, const float y, const float z,
const int* resolution, const float* aabb
const int resx, const int resy, const int resz,
const float* aabb
) {
// TODO(ruilongli): if (x, y, z) is outside the aabb, it will be clipped into the aabb! We should just return false instead.
int ix = (int)(((x - aabb[0]) / (aabb[3] - aabb[0])) * resolution[0]);
int iy = (int)(((y - aabb[1]) / (aabb[4] - aabb[1])) * resolution[1]);
int iz = (int)(((z - aabb[2]) / (aabb[5] - aabb[2])) * resolution[2]);
ix = __clamp(ix, 0, resolution[0]-1);
iy = __clamp(iy, 0, resolution[1]-1);
iz = __clamp(iz, 0, resolution[2]-1);
int idx = ix * resolution[1] * resolution[2] + iy * resolution[2] + iz;
int ix = (int)(((x - aabb[0]) / (aabb[3] - aabb[0])) * resx);
int iy = (int)(((y - aabb[1]) / (aabb[4] - aabb[1])) * resy);
int iz = (int)(((z - aabb[2]) / (aabb[5] - aabb[2])) * resz);
ix = __clamp(ix, 0, resx-1);
iy = __clamp(iy, 0, resy-1);
iz = __clamp(iz, 0, resz-1);
int idx = ix * resy * resz + iy * resz + iz;  // row-major flatten of (ix, iy, iz)
return idx;
}
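A hedged PyTorch equivalent of cascaded_grid_idx_at is sketched below with toy values (it is not the extension's API): normalize the point into the aabb, scale by the per-axis resolution, clamp, and flatten with row-major strides (resy*resz, resz, 1).

```python
# Hedged PyTorch equivalent of cascaded_grid_idx_at (toy values).
import torch

def grid_idx_at(xyz, aabb, res):
    resx, resy, resz = res
    res_t = torch.tensor(res)
    ijk = ((xyz - aabb[:3]) / (aabb[3:] - aabb[:3]) * res_t).long()
    ijk = ijk.clamp(min=torch.zeros(3, dtype=torch.long), max=res_t - 1)
    i, j, k = ijk.unbind(-1)
    return i * (resy * resz) + j * resz + k           # row-major linear index

aabb = torch.tensor([-1.5, -1.5, -1.5, 1.5, 1.5, 1.5])
print(grid_idx_at(torch.tensor([0.0, 0.0, 0.0]), aabb, (128, 128, 128)))  # center cell
```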
inline __device__ bool grid_occupied_at(
const float x, const float y, const float z,
const int* resolution, const float* aabb, const bool* occ_binary
const int resx, const int resy, const int resz,
const float* aabb, const bool* occ_binary
) {
int idx = cascaded_grid_idx_at(x, y, z, resolution, aabb);
int idx = cascaded_grid_idx_at(x, y, z, resx, resy, resz, aabb);
return occ_binary[idx];
}
@@ -28,13 +31,13 @@ inline __device__ float distance_to_next_voxel(
float x, float y, float z,
float dir_x, float dir_y, float dir_z,
float idir_x, float idir_y, float idir_z,
const int* resolution
const int resx, const int resy, const int resz
) { // dda like step
// TODO: warning: expression has no effect?
x, y, z = resolution[0] * x, resolution[1] * y, resolution[2] * z;
float tx = ((floorf(x + 0.5f + 0.5f * __sign(dir_x)) - x) * idir_x) / resolution[0];
float ty = ((floorf(y + 0.5f + 0.5f * __sign(dir_y)) - y) * idir_y) / resolution[1];
float tz = ((floorf(z + 0.5f + 0.5f * __sign(dir_z)) - z) * idir_z) / resolution[2];
x, y, z = resx * x, resy * y, resz * z;
float tx = ((floorf(x + 0.5f + 0.5f * __sign(dir_x)) - x) * idir_x) / resx;
float ty = ((floorf(y + 0.5f + 0.5f * __sign(dir_y)) - y) * idir_y) / resy;
float tz = ((floorf(z + 0.5f + 0.5f * __sign(dir_z)) - z) * idir_z) / resz;
float t = min(min(tx, ty), tz);
return fmaxf(t, 0.0f);
}
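A hedged Python transcription of distance_to_next_voxel with toy inputs, to make the DDA-style step explicit; `math.copysign` stands in for the kernel's __sign helper.

```python
# Hedged Python transcription of distance_to_next_voxel (toy inputs).
import math

def distance_to_next_voxel(x, y, z, dir, res):
    idir = tuple(1.0 / d if d != 0.0 else math.inf for d in dir)
    p = (x * res[0], y * res[1], z * res[2])          # position in grid units
    t = [((math.floor(p[i] + 0.5 + 0.5 * math.copysign(1.0, dir[i])) - p[i])
          * idir[i]) / res[i] for i in range(3)]
    return max(min(t), 0.0)                           # distance to the nearest voxel boundary

print(distance_to_next_voxel(0.3, 0.3, 0.3, dir=(1.0, 0.0, 0.0), res=(128, 128, 128)))
```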
@@ -44,10 +47,11 @@ inline __device__ float advance_to_next_voxel(
float x, float y, float z,
float dir_x, float dir_y, float dir_z,
float idir_x, float idir_y, float idir_z,
const int* resolution, float dt_min) {
const int resx, const int resy, const int resz,
float dt_min) {
// Regular stepping (may be slower but matches non-empty space)
float t_target = t + distance_to_next_voxel(
x, y, z, dir_x, dir_y, dir_z, idir_x, idir_y, idir_z, resolution
x, y, z, dir_x, dir_y, dir_z, idir_x, idir_y, idir_z, resx, resy, resz
);
do {
t += dt_min;
@@ -65,7 +69,9 @@ __global__ void kernel_raymarching(
const float* t_max, // shape (n_rays,)
// density grid
const float* aabb, // [min_x, min_y, min_z, max_x, max_y, max_z]
const int* resolution, // [reso_x, reso_y, reso_z]
const int resx,
const int resy,
const int resz,
const bool* occ_binary, // shape (reso_x, reso_y, reso_z)
// sampling
const int max_total_samples,
@@ -83,102 +89,102 @@ __global__ void kernel_raymarching(
) {
CUDA_GET_THREAD_ID(i, n_rays);
// locate
rays_o += i * 3;
rays_d += i * 3;
t_min += i;
t_max += i;
// // locate
// rays_o += i * 3;
// rays_d += i * 3;
// t_min += i;
// t_max += i;
const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
const float near = t_min[0], far = t_max[0];
// const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
// const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
// const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
// const float near = t_min[0], far = t_max[0];
uint32_t ray_idx, base, marching_samples;
uint32_t j;
float t0, t1, t_mid;
// uint32_t ray_idx, base, marching_samples;
// uint32_t j;
// float t0, t1, t_mid;
// first pass to compute an accurate number of steps
j = 0;
t0 = near; // TODO(ruilongli): perturb `near` as in ngp_pl?
t1 = t0 + dt;
t_mid = (t0 + t1) * 0.5f;
// // first pass to compute an accurate number of steps
// j = 0;
// t0 = near; // TODO(ruilongli): perturb `near` as in ngp_pl?
// t1 = t0 + dt;
// t_mid = (t0 + t1) * 0.5f;
while (t_mid < far && j < max_per_ray_samples) {
// current center
const float x = ox + t_mid * dx;
const float y = oy + t_mid * dy;
const float z = oz + t_mid * dz;
// while (t_mid < far && j < max_per_ray_samples) {
// // current center
// const float x = ox + t_mid * dx;
// const float y = oy + t_mid * dy;
// const float z = oz + t_mid * dz;
if (grid_occupied_at(x, y, z, resolution, aabb, occ_binary)) {
++j;
// march to next sample
t0 = t1;
t1 = t0 + dt;
t_mid = (t0 + t1) * 0.5f;
}
else {
// march to next sample
t_mid = advance_to_next_voxel(
t_mid, x, y, z, dx, dy, dz, rdx, rdy, rdz, resolution, dt
);
t0 = t_mid - dt * 0.5f;
t1 = t_mid + dt * 0.5f;
}
}
if (j == 0) return;
// if (grid_occupied_at(x, y, z, resx, resy, resz, aabb, occ_binary)) {
// ++j;
// // march to next sample
// t0 = t1;
// t1 = t0 + dt;
// t_mid = (t0 + t1) * 0.5f;
// }
// else {
// // march to next sample
// t_mid = advance_to_next_voxel(
// t_mid, x, y, z, dx, dy, dz, rdx, rdy, rdz, resx, resy, resz, dt
// );
// t0 = t_mid - dt * 0.5f;
// t1 = t_mid + dt * 0.5f;
// }
// }
// if (j == 0) return;
marching_samples = j;
base = atomicAdd(steps_counter, marching_samples);
if (base + marching_samples > max_total_samples) return;
ray_idx = atomicAdd(rays_counter, 1);
// marching_samples = j;
// base = atomicAdd(steps_counter, marching_samples);
// if (base + marching_samples > max_total_samples) return;
// ray_idx = atomicAdd(rays_counter, 1);
// locate
frustum_origins += base * 3;
frustum_dirs += base * 3;
frustum_starts += base;
frustum_ends += base;
// // locate
// frustum_origins += base * 3;
// frustum_dirs += base * 3;
// frustum_starts += base;
// frustum_ends += base;
// Second round
j = 0;
t0 = near;
t1 = t0 + dt;
t_mid = (t0 + t1) / 2.;
// // Second round
// j = 0;
// t0 = near;
// t1 = t0 + dt;
// t_mid = (t0 + t1) / 2.;
while (t_mid < far && j < marching_samples) {
// current center
const float x = ox + t_mid * dx;
const float y = oy + t_mid * dy;
const float z = oz + t_mid * dz;
// while (t_mid < far && j < marching_samples) {
// // current center
// const float x = ox + t_mid * dx;
// const float y = oy + t_mid * dy;
// const float z = oz + t_mid * dz;
if (grid_occupied_at(x, y, z, resolution, aabb, occ_binary)) {
frustum_origins[j * 3 + 0] = ox;
frustum_origins[j * 3 + 1] = oy;
frustum_origins[j * 3 + 2] = oz;
frustum_dirs[j * 3 + 0] = dx;
frustum_dirs[j * 3 + 1] = dy;
frustum_dirs[j * 3 + 2] = dz;
frustum_starts[j] = t0;
frustum_ends[j] = t1;
++j;
// march to next sample
t0 = t1;
t1 = t0 + dt;
t_mid = (t0 + t1) * 0.5f;
}
else {
// march to next sample
t_mid = advance_to_next_voxel(
t_mid, x, y, z, dx, dy, dz, rdx, rdy, rdz, resolution, dt
);
t0 = t_mid - dt * 0.5f;
t1 = t_mid + dt * 0.5f;
}
}
// if (grid_occupied_at(x, y, z, resx, resy, resz, aabb, occ_binary)) {
// frustum_origins[j * 3 + 0] = ox;
// frustum_origins[j * 3 + 1] = oy;
// frustum_origins[j * 3 + 2] = oz;
// frustum_dirs[j * 3 + 0] = dx;
// frustum_dirs[j * 3 + 1] = dy;
// frustum_dirs[j * 3 + 2] = dz;
// frustum_starts[j] = t0;
// frustum_ends[j] = t1;
// ++j;
// // march to next sample
// t0 = t1;
// t1 = t0 + dt;
// t_mid = (t0 + t1) * 0.5f;
// }
// else {
// // march to next sample
// t_mid = advance_to_next_voxel(
// t_mid, x, y, z, dx, dy, dz, rdx, rdy, rdz, resx, resy, resz, dt
// );
// t0 = t_mid - dt * 0.5f;
// t1 = t_mid + dt * 0.5f;
// }
// }
packed_info[ray_idx * 3 + 0] = i; // ray idx in {rays_o, rays_d}
packed_info[ray_idx * 3 + 1] = base; // point idx start.
packed_info[ray_idx * 3 + 2] = j; // point idx shift (actual marching samples).
// packed_info[ray_idx * 3 + 0] = i; // ray idx in {rays_o, rays_d}
// packed_info[ray_idx * 3 + 1] = base; // point idx start.
// packed_info[ray_idx * 3 + 2] = j; // point idx shift (actual marching samples).
return;
}
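The kernel (now commented out) uses a two-pass scheme: a first pass counts samples per ray, atomicAdd reserves space in the flat output buffers, and a second pass writes the samples, recording (ray idx, base offset, count) in packed_info. Below is a hedged PyTorch sketch of that packing layout, with a cumsum standing in for the atomic counter; it is an illustration, not the kernel.

```python
# Hedged sketch (plain PyTorch, not the CUDA kernel) of the two-pass packing:
# count samples per ray, reserve a base offset in the flat buffer, then write,
# with packed_info holding (ray idx, base offset, count) per ray.
import torch

def pack_samples(per_ray_t):
    counts = torch.tensor([len(t) for t in per_ray_t])
    base = torch.cumsum(counts, dim=0) - counts       # serial stand-in for atomicAdd
    packed_info = torch.stack([torch.arange(len(per_ray_t)), base, counts], dim=-1)
    t_all = torch.cat(per_ray_t)                      # flat sample buffer
    return packed_info, t_all

packed_info, t_all = pack_samples([torch.tensor([0.1, 0.2, 0.3]),
                                   torch.tensor([]),                # dead ray
                                   torch.tensor([0.5, 0.6])])
print(packed_info)  # tensor([[0, 0, 3], [1, 3, 0], [2, 3, 2]])
print(t_all)        # samples of ray i live at t_all[base : base + count]
```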
@@ -220,67 +226,69 @@ std::vector<torch::Tensor> ray_marching(
const torch::Tensor t_max,
// density grid
const torch::Tensor aabb,
const torch::Tensor resolution,
const pybind11::list resolution,
const torch::Tensor occ_binary,
// sampling
const int max_total_samples,
const int max_per_ray_samples,
const float dt
) {
DEVICE_GUARD(rays_o);
// DEVICE_GUARD(rays_o);
CHECK_INPUT(rays_o);
CHECK_INPUT(rays_d);
CHECK_INPUT(t_min);
CHECK_INPUT(t_max);
CHECK_INPUT(aabb);
CHECK_INPUT(resolution);
CHECK_INPUT(occ_binary);
// CHECK_INPUT(rays_o);
// CHECK_INPUT(rays_d);
// CHECK_INPUT(t_min);
// CHECK_INPUT(t_max);
// CHECK_INPUT(aabb);
// CHECK_INPUT(occ_binary);
const int n_rays = rays_o.size(0);
// const int n_rays = rays_o.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// // const int threads = 256;
// // const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// helper counter
torch::Tensor steps_counter = torch::zeros(
{1}, rays_o.options().dtype(torch::kInt32));
torch::Tensor rays_counter = torch::zeros(
{1}, rays_o.options().dtype(torch::kInt32));
// // helper counter
// torch::Tensor steps_counter = torch::zeros(
// {1}, rays_o.options().dtype(torch::kInt32));
// torch::Tensor rays_counter = torch::zeros(
// {1}, rays_o.options().dtype(torch::kInt32));
// output frustum samples
torch::Tensor packed_info = torch::zeros(
{n_rays, 3}, rays_o.options().dtype(torch::kInt32)); // ray_id, sample_id, num_samples
torch::Tensor frustum_origins = torch::zeros({max_total_samples, 3}, rays_o.options());
torch::Tensor frustum_dirs = torch::zeros({max_total_samples, 3}, rays_o.options());
torch::Tensor frustum_starts = torch::zeros({max_total_samples, 1}, rays_o.options());
torch::Tensor frustum_ends = torch::zeros({max_total_samples, 1}, rays_o.options());
// // output frustum samples
// torch::Tensor packed_info = torch::zeros(
// {n_rays, 3}, rays_o.options().dtype(torch::kInt32)); // ray_id, sample_id, num_samples
// torch::Tensor frustum_origins = torch::zeros({max_total_samples, 3}, rays_o.options());
// torch::Tensor frustum_dirs = torch::zeros({max_total_samples, 3}, rays_o.options());
// torch::Tensor frustum_starts = torch::zeros({max_total_samples, 1}, rays_o.options());
// torch::Tensor frustum_ends = torch::zeros({max_total_samples, 1}, rays_o.options());
kernel_raymarching<<<blocks, threads>>>(
// rays
n_rays,
rays_o.data_ptr<float>(),
rays_d.data_ptr<float>(),
t_min.data_ptr<float>(),
t_max.data_ptr<float>(),
// density grid
aabb.data_ptr<float>(),
resolution.data_ptr<int>(),
occ_binary.data_ptr<bool>(),
// sampling
max_total_samples,
max_per_ray_samples,
dt,
// writable helpers
steps_counter.data_ptr<int>(), // total samples.
rays_counter.data_ptr<int>(), // total rays.
packed_info.data_ptr<int>(),
frustum_origins.data_ptr<float>(),
frustum_dirs.data_ptr<float>(),
frustum_starts.data_ptr<float>(),
frustum_ends.data_ptr<float>()
);
// kernel_raymarching<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
// // rays
// n_rays,
// rays_o.data_ptr<float>(),
// rays_d.data_ptr<float>(),
// t_min.data_ptr<float>(),
// t_max.data_ptr<float>(),
// // density grid
// aabb.data_ptr<float>(),
// resolution[0].cast<int>(),
// resolution[1].cast<int>(),
// resolution[2].cast<int>(),
// occ_binary.data_ptr<bool>(),
// // sampling
// max_total_samples,
// max_per_ray_samples,
// dt,
// // writable helpers
// steps_counter.data_ptr<int>(), // total samples.
// rays_counter.data_ptr<int>(), // total rays.
// packed_info.data_ptr<int>(),
// frustum_origins.data_ptr<float>(),
// frustum_dirs.data_ptr<float>(),
// frustum_starts.data_ptr<float>(),
// frustum_ends.data_ptr<float>()
// );
return {packed_info, frustum_origins, frustum_dirs, frustum_starts, frustum_ends, steps_counter};
// return {packed_info, frustum_origins, frustum_dirs, frustum_starts, frustum_ends, steps_counter};
return {};
}
@@ -22,7 +22,8 @@ def volumetric_rendering(
device = rays_o.device
if render_bkgd is None:
render_bkgd = torch.ones(3, device=device)
scene_resolution = torch.tensor(scene_resolution, dtype=torch.int, device=device)
# scene_resolution = torch.tensor(scene_resolution, dtype=torch.int, device=device)
rays_o = rays_o.contiguous()
rays_d = rays_d.contiguous()
@@ -40,18 +41,17 @@ def volumetric_rendering(
)
with torch.no_grad():
# TODO: avoid clamp here. kinda stupid
t_min, t_max = ray_aabb_intersect(rays_o, rays_d, scene_aabb)
t_min = torch.clamp(t_min, max=1e10)
t_max = torch.clamp(t_max, max=1e10)
# t_min = torch.clamp(t_min, max=1e10)
# t_max = torch.clamp(t_max, max=1e10)
(
packed_info,
frustum_origins,
frustum_dirs,
frustum_starts,
frustum_ends,
steps_counter,
# packed_info,
# frustum_origins,
# frustum_dirs,
# frustum_starts,
# frustum_ends,
# steps_counter,
) = ray_marching(
# rays
rays_o,
@@ -68,43 +68,43 @@ def volumetric_rendering(
render_step_size,
)
# squeeze valid samples
total_samples = max(packed_info[:, -1].sum(), 1)
total_samples = int(math.ceil(total_samples / 128.0)) * 128
frustum_origins = frustum_origins[:total_samples]
frustum_dirs = frustum_dirs[:total_samples]
frustum_starts = frustum_starts[:total_samples]
frustum_ends = frustum_ends[:total_samples]
# # squeeze valid samples
# total_samples = max(packed_info[:, -1].sum(), 1)
# total_samples = int(math.ceil(total_samples / 128.0)) * 128
# frustum_origins = frustum_origins[:total_samples]
# frustum_dirs = frustum_dirs[:total_samples]
# frustum_starts = frustum_starts[:total_samples]
# frustum_ends = frustum_ends[:total_samples]
frustum_positions = (
frustum_origins + frustum_dirs * (frustum_starts + frustum_ends) / 2.0
)
# frustum_positions = (
# frustum_origins + frustum_dirs * (frustum_starts + frustum_ends) / 2.0
# )
query_results = query_fn(frustum_positions, frustum_dirs, **kwargs)
rgbs, densities = query_results[0], query_results[1]
# query_results = query_fn(frustum_positions, frustum_dirs, **kwargs)
# rgbs, densities = query_results[0], query_results[1]
(
accumulated_weight,
accumulated_depth,
accumulated_color,
alive_ray_mask,
compact_steps_counter,
) = VolumeRenderer.apply(
packed_info,
frustum_starts,
frustum_ends,
densities.contiguous(),
rgbs.contiguous(),
)
# (
# accumulated_weight,
# accumulated_depth,
# accumulated_color,
# alive_ray_mask,
# compact_steps_counter,
# ) = VolumeRenderer.apply(
# packed_info,
# frustum_starts,
# frustum_ends,
# densities.contiguous(),
# rgbs.contiguous(),
# )
accumulated_depth = torch.clip(accumulated_depth, t_min[:, None], t_max[:, None])
accumulated_color = accumulated_color + render_bkgd * (1.0 - accumulated_weight)
# accumulated_depth = torch.clip(accumulated_depth, t_min[:, None], t_max[:, None])
# accumulated_color = accumulated_color + render_bkgd * (1.0 - accumulated_weight)
return (
accumulated_color,
accumulated_depth,
accumulated_weight,
alive_ray_mask,
steps_counter,
compact_steps_counter,
)
# return (
# accumulated_color,
# accumulated_depth,
# accumulated_weight,
# alive_ray_mask,
# steps_counter,
# compact_steps_counter,
# )
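For reference, the per-ray piece that VolumeRenderer.apply and the two commented-out compositing lines implement is standard emission-absorption volume rendering; a hedged single-ray sketch with toy numbers follows.

```python
# Hedged single-ray sketch (toy numbers) of the emission-absorption compositing:
# alpha-composite the samples, then blend leftover transmittance with the background.
import torch

def composite_one_ray(rgbs, densities, t_starts, t_ends, bkgd):
    deltas = t_ends - t_starts
    alphas = 1.0 - torch.exp(-densities * deltas)                   # [n_samples]
    trans = torch.cumprod(torch.cat([torch.ones(1), 1.0 - alphas]), dim=0)[:-1]
    weights = alphas * trans                                        # alpha * transmittance
    color = (weights[:, None] * rgbs).sum(dim=0)
    opacity = weights.sum()
    return color + bkgd * (1.0 - opacity), opacity

rgbs = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
color, opacity = composite_one_ray(rgbs,
                                   densities=torch.tensor([5.0, 5.0]),
                                   t_starts=torch.tensor([0.0, 0.1]),
                                   t_ends=torch.tensor([0.1, 0.2]),
                                   bkgd=torch.ones(3))
print(color, opacity)
```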