Commit 65bebd64 authored by Ruilong Li

training recipe

parent 630f2596
@@ -12,12 +12,15 @@ python examples/trainval.py
 Tested with the default settings on the Lego test set.

-| Model | Split | PSNR | Train Time | Test Speed | GPU |
-| - | - | - | - | - | - |
+| Model | Split | PSNR | Train Time | Test Speed | GPU | Train Memory |
+| - | - | - | - | - | - | - |
 | instant-ngp (paper) | trainval? | 36.39 | - | - | 3090 |
+| instant-ngp (code) | train (35k steps) | 36.08 | 308 sec | 55.32 fps | TITAN RTX | 1734MB |
 | torch-ngp (`-O`) | train (30K steps) | 34.15 | 310 sec | 7.8 fps | V100 |
 | ours | train (30K steps) | 34.40 | 296 sec | 6.2 fps | TITAN RTX |
 | ours | trainval (30K steps) | 35.42 | 291 sec | 6.4 fps | TITAN RTX |
+| ours (2**16 samples w preload) | trainval (35K steps) | 36.18 | 385 sec | 8.3 fps | TITAN RTX |
+| ours (2**16 samples w preload) | train (35K steps) | 35.03 | 383 sec | 8.0 fps | TITAN RTX |

 ## Tips:
@@ -147,23 +147,28 @@ class SubjectLoader(torch.utils.data.Dataset):
             **{k: v for k, v in data.items() if k not in ["rgba", "rays"]},
         }

+    def update_num_rays(self, num_rays):
+        self.num_rays = num_rays
+
     def fetch_data(self, index):
         """Fetch the data (it may be cached for multiple batches)."""
+        num_rays = self.num_rays
+
         if self.training:
             if self.batch_over_images:
                 image_id = torch.randint(
                     0,
                     len(self.images),
-                    size=(self.num_rays,),
+                    size=(num_rays,),
                     device=self.images.device,
                 )
             else:
                 image_id = [index]
             x = torch.randint(
-                0, self.WIDTH, size=(self.num_rays,), device=self.images.device
+                0, self.WIDTH, size=(num_rays,), device=self.images.device
             )
             y = torch.randint(
-                0, self.HEIGHT, size=(self.num_rays,), device=self.images.device
+                0, self.HEIGHT, size=(num_rays,), device=self.images.device
             )
         else:
             image_id = [index]
@@ -197,9 +202,9 @@ class SubjectLoader(torch.utils.data.Dataset):
         viewdirs = directions / torch.linalg.norm(directions, dim=-1, keepdims=True)

         if self.training:
-            origins = torch.reshape(origins, (self.num_rays, 3))
-            viewdirs = torch.reshape(viewdirs, (self.num_rays, 3))
-            rgba = torch.reshape(rgba, (self.num_rays, 4))
+            origins = torch.reshape(origins, (num_rays, 3))
+            viewdirs = torch.reshape(viewdirs, (num_rays, 3))
+            rgba = torch.reshape(rgba, (num_rays, 4))
         else:
             origins = torch.reshape(origins, (self.HEIGHT, self.WIDTH, 3))
             viewdirs = torch.reshape(viewdirs, (self.HEIGHT, self.WIDTH, 3))
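Note: with `update_num_rays` in place, the ray batch can be resized between steps without rebuilding the loader; `fetch_data` snapshots `self.num_rays` into a local at the start so a resize mid-fetch cannot mix two batch sizes. A minimal usage sketch (the data path is a placeholder):

```python
train_dataset = SubjectLoader(
    subject_id="lego",
    root_fp="/path/to/nerf_synthetic/",  # placeholder path
    split="train",
    num_rays=4096,
)
data = train_dataset[0]              # batch of 4096 rays
train_dataset.update_num_rays(8192)  # takes effect on the next fetch
data = train_dataset[0]              # batch of 8192 rays
```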
@@ -32,9 +32,10 @@ def render_image(radiance_field, rays, render_bkgd):
     num_rays, _ = rays_shape

     results = []
     chunk = torch.iinfo(torch.int32).max if radiance_field.training else 81920
+    render_est_n_samples = 2**16 * 16 if radiance_field.training else None
     for i in range(0, num_rays, chunk):
         chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays)
-        chunk_color, chunk_depth, chunk_weight, alive_ray_mask, = volumetric_rendering(
+        chunk_results = volumetric_rendering(
             query_fn=radiance_field.forward,  # {x, dir} -> {rgb, density}
             rays_o=chunk_rays.origins,
             rays_d=chunk_rays.viewdirs,
@@ -43,14 +44,19 @@ def render_image(radiance_field, rays, render_bkgd):
             scene_resolution=occ_field.resolution,
             render_bkgd=render_bkgd,
             render_n_samples=render_n_samples,
+            render_est_n_samples=render_est_n_samples,  # memory control: worst case
         )
-        results.append([chunk_color, chunk_depth, chunk_weight, alive_ray_mask])
-    rgb, depth, acc, alive_ray_mask = [torch.cat(r, dim=0) for r in zip(*results)]
+        results.append(chunk_results)
+    rgb, depth, acc, alive_ray_mask, counter, compact_counter = [
+        torch.cat(r, dim=0) for r in zip(*results)
+    ]
     return (
         rgb.view((*rays_shape[:-1], -1)),
         depth.view((*rays_shape[:-1], -1)),
         acc.view((*rays_shape[:-1], -1)),
         alive_ray_mask.view(*rays_shape[:-1]),
+        counter.sum(),
+        compact_counter.sum(),
     )
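Note: `render_image` now forwards whatever tuple `volumetric_rendering` returns per chunk, so the concatenation has to regroup tuples field by field. The `zip(*results)` idiom in isolation (a generic sketch, not the actual renderer):

```python
import torch

# each chunk yields a tuple of per-field tensors; zip(*results) regroups
# them so the k-th field of every chunk can be concatenated along dim 0
results = [
    (torch.ones(2, 3), torch.zeros(2, 1)),
    (torch.ones(4, 3), torch.zeros(4, 1)),
]
rgb, depth = [torch.cat(r, dim=0) for r in zip(*results)]
assert rgb.shape == (6, 3) and depth.shape == (6, 1)
```

The per-chunk counter tensors have shape `(1,)`, so after concatenation they are collapsed with `.sum()` into a single scalar per rendered image.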
@@ -63,17 +69,18 @@ if __name__ == "__main__":
     train_dataset = SubjectLoader(
         subject_id="lego",
         root_fp="/home/ruilongli/data/nerf_synthetic/",
-        split="trainval",
-        num_rays=8192,
+        split="train",
+        num_rays=4096,
     )
-    # train_dataset.images = train_dataset.images.to(device)
-    # train_dataset.camtoworlds = train_dataset.camtoworlds.to(device)
-    # train_dataset.K = train_dataset.K.to(device)
+
+    train_dataset.images = train_dataset.images.to(device)
+    train_dataset.camtoworlds = train_dataset.camtoworlds.to(device)
+    train_dataset.K = train_dataset.K.to(device)
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset,
-        num_workers=4,
+        num_workers=0,
         batch_size=None,
-        persistent_workers=True,
+        # persistent_workers=True,
         shuffle=True,
     )
@@ -83,12 +90,12 @@ if __name__ == "__main__":
         split="test",
         num_rays=None,
     )
-    # test_dataset.images = test_dataset.images.to(device)
-    # test_dataset.camtoworlds = test_dataset.camtoworlds.to(device)
-    # test_dataset.K = test_dataset.K.to(device)
+    test_dataset.images = test_dataset.images.to(device)
+    test_dataset.camtoworlds = test_dataset.camtoworlds.to(device)
+    test_dataset.K = test_dataset.K.to(device)
     test_dataloader = torch.utils.data.DataLoader(
         test_dataset,
-        num_workers=4,
+        num_workers=0,
         batch_size=None,
     )
@@ -107,7 +114,13 @@ if __name__ == "__main__":
         (scene_aabb[3:] - scene_aabb[:3]).max() * math.sqrt(3) / render_n_samples
     )

-    optimizer = torch.optim.Adam(radiance_field.parameters(), lr=3e-3, eps=1e-15)
+    optimizer = torch.optim.Adam(
+        radiance_field.parameters(),
+        lr=1e-2,
+        # betas=(0.9, 0.99),
+        eps=1e-15,
+        # weight_decay=1e-6,
+    )
     scheduler = torch.optim.lr_scheduler.MultiStepLR(
         optimizer, milestones=[20000, 30000], gamma=0.1
     )
@@ -136,11 +149,11 @@ if __name__ == "__main__":
     tic = time.time()
     data_time = 0
     tic_data = time.time()
-    for epoch in range(300):
-        for data in train_dataloader:
+    for epoch in range(400):
+        for i in range(len(train_dataset)):
+            data = train_dataset[i]
             data_time += time.time() - tic_data
-            step += 1
-            if step > 30_000:
+            if step > 35_000:
                 print("training stops")
                 exit()
@@ -152,25 +165,32 @@ if __name__ == "__main__":
             # update occupancy grid
            occ_field.every_n_step(step)

-            rgb, depth, acc, alive_ray_mask = render_image(
+            rgb, depth, acc, alive_ray_mask, counter, compact_counter = render_image(
                 radiance_field, rays, render_bkgd
             )
+            num_rays = len(pixels)
+            num_rays = int(num_rays * (2**16 / float(compact_counter)))
+            num_rays = int(math.ceil(num_rays / 128.0) * 128)
+            train_dataset.update_num_rays(num_rays)

             # compute loss
-            loss = F.mse_loss(rgb, pixels)
+            loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])

             optimizer.zero_grad()
-            loss.backward()
+            (loss * 128.0).backward()
             optimizer.step()
             scheduler.step()

             if step % 50 == 0:
                 elapsed_time = time.time() - tic
                 print(
-                    f"elapsed_time={elapsed_time:.2f}s (data={data_time:.2f}s) | {step=} | loss={loss:.5f}"
+                    f"elapsed_time={elapsed_time:.2f}s (data={data_time:.2f}s) | {step=} | "
+                    f"loss={loss:.5f} | "
+                    f"alive_ray_mask={alive_ray_mask.long().sum():d} | "
+                    f"counter={counter:d} | compact_counter={compact_counter:d} | num_rays={len(pixels):d} "
                 )

-            if step % 30_000 == 0 and step > 0:
+            if step % 35_000 == 0 and step > 0:
                 # evaluation
                 radiance_field.eval()

                 psnrs = []
@@ -181,7 +201,7 @@ if __name__ == "__main__":
                     pixels = data["pixels"].to(device)
                     render_bkgd = data["color_bkgd"].to(device)

                     # rendering
-                    rgb, depth, acc, alive_ray_mask = render_image(
+                    rgb, depth, acc, alive_ray_mask, _, _ = render_image(
                         radiance_field, rays, render_bkgd
                     )
                     mse = F.mse_loss(rgb, pixels)
@@ -191,6 +211,8 @@ if __name__ == "__main__":
                 print(f"evaluation: {psnr_avg=}")

             tic_data = time.time()
+            step += 1

 # "train"
 # elapsed_time=298.27s (data=60.08s) | step=30000 | loss=0.00026
 # evaluation: psnr_avg=33.305334663391115 (6.42 it/s)
@@ -210,3 +232,7 @@ if __name__ == "__main__":
 # "trainval" batch_over_images=True, schedule
 # elapsed_time=291.42s (data=52.82s) | step=30000 | loss=0.00020
 # evaluation: psnr_avg=35.41630497932434 (6.40 it/s)
+
+# "trainval" batch_over_images=True, schedule 2**18
+# evaluation: psnr_avg=36.24 (6.75 it/s)
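Note: the heart of this recipe is the dynamic ray batch: after every step the ray count is rescaled so the renderer keeps producing roughly 2**16 compacted samples per batch, rounded up to a multiple of 128. Pulled out of the loop above (a sketch; the 2**16 target and the 128 rounding are exactly the constants in the diff):

```python
import math

def updated_num_rays(num_rays: int, compact_counter: int,
                     target: int = 2**16) -> int:
    """Rescale the ray batch so the next step yields ~target samples."""
    num_rays = int(num_rays * (target / float(compact_counter)))
    # round up to a multiple of 128 to keep shapes in a small bucket set
    return int(math.ceil(num_rays / 128.0) * 128)
```

The loss is now also masked to alive rays and multiplied by a constant 128 before `backward()`; since Adam is nearly scale-invariant (up to `eps`), this presumably acts as a fixed loss scale that keeps tiny gradients from underflowing rather than as a learning-rate change.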
@@ -20,6 +20,7 @@ class VolumeRenderer(torch.autograd.Function):
             accumulated_depth,
             accumulated_color,
             mask,
+            steps_counter,
         ) = volumetric_rendering_forward(packed_info, starts, ends, sigmas, rgbs)
         ctx.save_for_backward(
             accumulated_weight,
@@ -31,11 +32,19 @@ class VolumeRenderer(torch.autograd.Function):
             sigmas,
             rgbs,
         )
-        return accumulated_weight, accumulated_depth, accumulated_color, mask
+        return (
+            accumulated_weight,
+            accumulated_depth,
+            accumulated_color,
+            mask,
+            steps_counter,
+        )

     @staticmethod
     @custom_bwd
-    def backward(ctx, grad_weight, grad_depth, grad_color, _grad_mask):
+    def backward(
+        ctx, grad_weight, grad_depth, grad_color, _grad_mask, _grad_steps_counter
+    ):
         (
             accumulated_weight,
             accumulated_depth,
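Note: the pattern above — returning an extra non-differentiable tensor from a `torch.autograd.Function`, with `backward` accepting and ignoring its incoming gradient — looks like this in isolation (a minimal sketch, not the actual renderer; `mark_non_differentiable` is optional and not used in the diff):

```python
import torch

class ScaleAndCount(torch.autograd.Function):
    """Scales the input and also returns a gradient-free counter."""

    @staticmethod
    def forward(ctx, x):
        count = torch.tensor([x.numel()], dtype=torch.int32)
        ctx.mark_non_differentiable(count)
        return x * 2.0, count

    @staticmethod
    def backward(ctx, grad_out, _grad_count):
        # one incoming grad per forward output; the counter's is ignored
        return grad_out * 2.0

y, n = ScaleAndCount.apply(torch.ones(3, requires_grad=True))
y.sum().backward()  # gradients flow through y only
```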
@@ -281,6 +281,6 @@ std::vector<torch::Tensor> ray_marching(
         frustum_ends.data_ptr<float>()
     );
-    return {packed_info, frustum_origins, frustum_dirs, frustum_starts, frustum_ends};
+    return {packed_info, frustum_origins, frustum_dirs, frustum_starts, frustum_ends, steps_counter};
 }
@@ -13,7 +13,9 @@ __global__ void volumetric_rendering_forward_kernel(
     scalar_t* accumulated_weight,  // output
     scalar_t* accumulated_depth,  // output
     scalar_t* accumulated_color,  // output
-    bool* mask  // output
+    bool* mask,  // output
+    // writable helpers
+    int* steps_counter
 ) {
     CUDA_GET_THREAD_ID(thread_id, n_rays);
@@ -54,6 +56,7 @@ __global__ void volumetric_rendering_forward_kernel(
         T *= (1.f - alpha);
     }
     mask[0] = true;
+    atomicAdd(steps_counter, j);
 }
@@ -178,6 +181,10 @@ std::vector<torch::Tensor> volumetric_rendering_forward(
     const int threads = 256;
     const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);

+    // helper counter
+    torch::Tensor steps_counter = torch::zeros(
+        {1}, rgbs.options().dtype(torch::kInt32));
+
     // outputs
     torch::Tensor accumulated_weight = torch::zeros({n_rays, 1}, sigmas.options());
     torch::Tensor accumulated_depth = torch::zeros({n_rays, 1}, sigmas.options());
@@ -199,11 +206,12 @@ std::vector<torch::Tensor> volumetric_rendering_forward(
             accumulated_weight.data_ptr<scalar_t>(),
             accumulated_depth.data_ptr<scalar_t>(),
             accumulated_color.data_ptr<scalar_t>(),
-            mask.data_ptr<bool>()
+            mask.data_ptr<bool>(),
+            steps_counter.data_ptr<int>()
         );
     }));

-    return {accumulated_weight, accumulated_depth, accumulated_color, mask};
+    return {accumulated_weight, accumulated_depth, accumulated_color, mask, steps_counter};
 }
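Note: the kernel-side counter is a single int32 cell that each ray's thread `atomicAdd`s its sample count `j` into. The counter returned by `ray_marching` should match the per-ray sample counts already stored in `packed_info`, while the renderer's counter can be smaller because samples are compacted before shading (and rays can terminate early) — that difference appears to be what `counter` vs. `compact_counter` measures in the training loop. A sanity-check sketch (assuming, as in the squeeze step below, that the last column of `packed_info` holds per-ray sample counts):

```python
def total_steps_from_packed_info(packed_info):
    # packed_info[:, -1] stores how many samples each ray produced;
    # its sum should reproduce the marching kernel's atomicAdd counter
    return int(packed_info[:, -1].sum())
```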
@@ -34,7 +34,7 @@ def volumetric_rendering(
     if render_est_n_samples is None:
         render_total_samples = n_rays * render_n_samples
     else:
-        render_total_samples = n_rays * render_est_n_samples
+        render_total_samples = render_est_n_samples
     render_step_size = (
         (scene_aabb[3:] - scene_aabb[:3]).max() * math.sqrt(3) / render_n_samples
     )
@@ -51,6 +51,7 @@ def volumetric_rendering(
         frustum_dirs,
         frustum_starts,
         frustum_ends,
+        steps_counter,
     ) = ray_marching(
         # rays
         rays_o,
@@ -69,6 +70,7 @@ def volumetric_rendering(
     # squeeze valid samples
     total_samples = max(packed_info[:, -1].sum(), 1)
+    total_samples = int(math.ceil(total_samples / 128.0)) * 128
     frustum_origins = frustum_origins[:total_samples]
     frustum_dirs = frustum_dirs[:total_samples]
     frustum_starts = frustum_starts[:total_samples]
@@ -86,6 +88,7 @@ def volumetric_rendering(
         accumulated_depth,
         accumulated_color,
         alive_ray_mask,
+        compact_steps_counter,
     ) = VolumeRenderer.apply(
         packed_info,
         frustum_starts,
@@ -97,4 +100,11 @@ def volumetric_rendering(
     accumulated_depth = torch.clip(accumulated_depth, t_min[:, None], t_max[:, None])
     accumulated_color = accumulated_color + render_bkgd * (1.0 - accumulated_weight)
-    return accumulated_color, accumulated_depth, accumulated_weight, alive_ray_mask
+    return (
+        accumulated_color,
+        accumulated_depth,
+        accumulated_weight,
+        alive_ray_mask,
+        steps_counter,
+        compact_steps_counter,
+    )
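Note: `render_est_n_samples` (a fixed worst-case buffer during training) and the round-up of `total_samples` to a multiple of 128 serve the same goal: keeping tensor shapes in a small, predictable set so peak memory stays bounded and allocations can be reused across steps. The two sizing rules side by side (a sketch; names mirror the diff):

```python
import math

def buffer_size(n_rays, render_n_samples, render_est_n_samples=None):
    # training passes a fixed worst-case estimate; eval sizes exactly
    if render_est_n_samples is None:
        return n_rays * render_n_samples
    return render_est_n_samples

def compact_size(valid_samples):
    # round the valid-sample count up to a multiple of 128 (shape bucketing)
    return int(math.ceil(max(valid_samples, 1) / 128.0)) * 128
```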