Commit 65bebd64 authored by Ruilong Li

training recipe

parent 630f2596
@@ -12,12 +12,15 @@ python examples/trainval.py
 Tested with the default settings on the Lego test set.
-| Model | Split | PSNR | Train Time | Test Speed | GPU |
-| - | - | - | - | - | - |
+| Model | Split | PSNR | Train Time | Test Speed | GPU | Train Memory |
+| - | - | - | - | - | - | - |
 | instant-ngp (paper) | trainval? | 36.39 | - | - | 3090 |
+| instant-ngp (code) | train (35K steps) | 36.08 | 308 sec | 55.32 fps | TITAN RTX | 1734MB |
 | torch-ngp (`-O`) | train (30K steps) | 34.15 | 310 sec | 7.8 fps | V100 |
 | ours | train (30K steps) | 34.40 | 296 sec | 6.2 fps | TITAN RTX |
 | ours | trainval (30K steps) | 35.42 | 291 sec | 6.4 fps | TITAN RTX |
+| ours (2**16 samples w/ preload) | trainval (35K steps) | 36.18 | 385 sec | 8.3 fps | TITAN RTX |
+| ours (2**16 samples w/ preload) | train (35K steps) | 35.03 | 383 sec | 8.0 fps | TITAN RTX |
 ## Tips:
@@ -147,23 +147,28 @@ class SubjectLoader(torch.utils.data.Dataset):
             **{k: v for k, v in data.items() if k not in ["rgba", "rays"]},
         }

+    def update_num_rays(self, num_rays):
+        self.num_rays = num_rays
+
     def fetch_data(self, index):
         """Fetch the data (it may be cached for multiple batches)."""
+        num_rays = self.num_rays
         if self.training:
             if self.batch_over_images:
                 image_id = torch.randint(
                     0,
                     len(self.images),
-                    size=(self.num_rays,),
+                    size=(num_rays,),
                     device=self.images.device,
                 )
             else:
                 image_id = [index]
             x = torch.randint(
-                0, self.WIDTH, size=(self.num_rays,), device=self.images.device
+                0, self.WIDTH, size=(num_rays,), device=self.images.device
             )
             y = torch.randint(
-                0, self.HEIGHT, size=(self.num_rays,), device=self.images.device
+                0, self.HEIGHT, size=(num_rays,), device=self.images.device
             )
         else:
             image_id = [index]
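The new `update_num_rays` hook lets the training loop resize the ray batch between fetches; copying `self.num_rays` into a local at the top of `fetch_data` keeps a single fetch internally consistent even if the hook fires mid-epoch. A minimal usage sketch, reusing the constructor arguments from elsewhere in this commit:

```python
# Sketch: resize the ray batch between fetches via the new hook.
train_dataset = SubjectLoader(
    subject_id="lego",
    root_fp="/home/ruilongli/data/nerf_synthetic/",
    split="train",
    num_rays=4096,
)
data = train_dataset[0]              # draws 4096 random rays
train_dataset.update_num_rays(8192)  # e.g. the scene turned out cheap to sample
data = train_dataset[0]              # the next fetch draws 8192 rays
```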
@@ -197,9 +202,9 @@ class SubjectLoader(torch.utils.data.Dataset):
         viewdirs = directions / torch.linalg.norm(directions, dim=-1, keepdims=True)
         if self.training:
-            origins = torch.reshape(origins, (self.num_rays, 3))
-            viewdirs = torch.reshape(viewdirs, (self.num_rays, 3))
-            rgba = torch.reshape(rgba, (self.num_rays, 4))
+            origins = torch.reshape(origins, (num_rays, 3))
+            viewdirs = torch.reshape(viewdirs, (num_rays, 3))
+            rgba = torch.reshape(rgba, (num_rays, 4))
         else:
             origins = torch.reshape(origins, (self.HEIGHT, self.WIDTH, 3))
             viewdirs = torch.reshape(viewdirs, (self.HEIGHT, self.WIDTH, 3))
@@ -32,9 +32,10 @@ def render_image(radiance_field, rays, render_bkgd):
     num_rays, _ = rays_shape
     results = []
     chunk = torch.iinfo(torch.int32).max if radiance_field.training else 81920
+    render_est_n_samples = 2**16 * 16 if radiance_field.training else None
     for i in range(0, num_rays, chunk):
         chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays)
-        chunk_color, chunk_depth, chunk_weight, alive_ray_mask, = volumetric_rendering(
+        chunk_results = volumetric_rendering(
             query_fn=radiance_field.forward,  # {x, dir} -> {rgb, density}
             rays_o=chunk_rays.origins,
             rays_d=chunk_rays.viewdirs,
@@ -43,14 +44,19 @@
             scene_resolution=occ_field.resolution,
             render_bkgd=render_bkgd,
             render_n_samples=render_n_samples,
+            render_est_n_samples=render_est_n_samples,  # memory control: worst case
         )
-        results.append([chunk_color, chunk_depth, chunk_weight, alive_ray_mask])
-    rgb, depth, acc, alive_ray_mask = [torch.cat(r, dim=0) for r in zip(*results)]
+        results.append(chunk_results)
+    rgb, depth, acc, alive_ray_mask, counter, compact_counter = [
+        torch.cat(r, dim=0) for r in zip(*results)
+    ]
     return (
         rgb.view((*rays_shape[:-1], -1)),
         depth.view((*rays_shape[:-1], -1)),
         acc.view((*rays_shape[:-1], -1)),
         alive_ray_mask.view(*rays_shape[:-1]),
+        counter.sum(),
+        compact_counter.sum(),
    )
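Callers of `render_image` now unpack six values: the two new scalars report how many samples ray marching produced (`counter`) and how many survived to the compacted rendering pass (`compact_counter`). A rough diagnostic sketch of how they can be used:

```python
# Sketch: average per-ray sample cost, useful when tuning the 2**16 budget.
rgb, depth, acc, alive_ray_mask, counter, compact_counter = render_image(
    radiance_field, rays, render_bkgd
)
samples_per_ray = compact_counter.item() / max(int(alive_ray_mask.sum()), 1)
```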
@@ -63,17 +69,18 @@ if __name__ == "__main__":
     train_dataset = SubjectLoader(
         subject_id="lego",
         root_fp="/home/ruilongli/data/nerf_synthetic/",
-        split="trainval",
-        num_rays=8192,
+        split="train",
+        num_rays=4096,
     )
-    # train_dataset.images = train_dataset.images.to(device)
-    # train_dataset.camtoworlds = train_dataset.camtoworlds.to(device)
-    # train_dataset.K = train_dataset.K.to(device)
+    train_dataset.images = train_dataset.images.to(device)
+    train_dataset.camtoworlds = train_dataset.camtoworlds.to(device)
+    train_dataset.K = train_dataset.K.to(device)
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset,
-        num_workers=4,
+        num_workers=0,
         batch_size=None,
-        persistent_workers=True,
+        # persistent_workers=True,
         shuffle=True,
     )
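Preloading the images, poses, and intrinsics onto the GPU is why `num_workers` drops to 0: CUDA tensors do not survive the fork into DataLoader worker processes. The training loop later in this diff also indexes the dataset directly instead of iterating the loader, since a resizable ray batch is easiest to drive from the main process. A sketch of the pattern, assuming the same names:

```python
# Sketch: with GPU-resident data there is nothing for workers to do, so the
# loop can bypass the DataLoader and index the dataset directly.
for i in range(len(train_dataset)):
    data = train_dataset[i]  # already on `device`; no host-to-device copy here
```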
@@ -83,12 +90,12 @@ if __name__ == "__main__":
         split="test",
         num_rays=None,
     )
-    # test_dataset.images = test_dataset.images.to(device)
-    # test_dataset.camtoworlds = test_dataset.camtoworlds.to(device)
-    # test_dataset.K = test_dataset.K.to(device)
+    test_dataset.images = test_dataset.images.to(device)
+    test_dataset.camtoworlds = test_dataset.camtoworlds.to(device)
+    test_dataset.K = test_dataset.K.to(device)
     test_dataloader = torch.utils.data.DataLoader(
         test_dataset,
-        num_workers=4,
+        num_workers=0,
         batch_size=None,
     )
@@ -107,7 +114,13 @@ if __name__ == "__main__":
         (scene_aabb[3:] - scene_aabb[:3]).max() * math.sqrt(3) / render_n_samples
     )
-    optimizer = torch.optim.Adam(radiance_field.parameters(), lr=3e-3, eps=1e-15)
+    optimizer = torch.optim.Adam(
+        radiance_field.parameters(),
+        lr=1e-2,
+        # betas=(0.9, 0.99),
+        eps=1e-15,
+        # weight_decay=1e-6,
+    )
     scheduler = torch.optim.lr_scheduler.MultiStepLR(
         optimizer, milestones=[20000, 30000], gamma=0.1
     )
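The raised learning rate still decays on the same milestones: `MultiStepLR` multiplies by `gamma` each time a milestone is passed. A closed-form sanity check (hypothetical helper, not part of the commit):

```python
def lr_at(step, base=1e-2, milestones=(20000, 30000), gamma=0.1):
    # closed form of MultiStepLR: one gamma factor per milestone already passed
    return base * gamma ** sum(step >= m for m in milestones)

print(lr_at(10_000), lr_at(25_000), lr_at(34_000))  # ~1e-2, ~1e-3, ~1e-4
```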
@@ -136,11 +149,11 @@ if __name__ == "__main__":
     tic = time.time()
     data_time = 0
     tic_data = time.time()
-    for epoch in range(300):
-        for data in train_dataloader:
+    for epoch in range(400):
+        for i in range(len(train_dataset)):
+            data = train_dataset[i]
             data_time += time.time() - tic_data
-            step += 1
-            if step > 30_000:
+            if step > 35_000:
                 print("training stops")
                 exit()
@@ -152,25 +165,32 @@ if __name__ == "__main__":
             # update occupancy grid
             occ_field.every_n_step(step)
-            rgb, depth, acc, alive_ray_mask = render_image(
+            rgb, depth, acc, alive_ray_mask, counter, compact_counter = render_image(
                 radiance_field, rays, render_bkgd
             )
+            num_rays = len(pixels)
+            num_rays = int(num_rays * (2**16 / float(compact_counter)))
+            num_rays = int(math.ceil(num_rays / 128.0) * 128)
+            train_dataset.update_num_rays(num_rays)
             # compute loss
-            loss = F.mse_loss(rgb, pixels)
+            loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask])
             optimizer.zero_grad()
-            loss.backward()
+            (loss * 128.0).backward()
             optimizer.step()
             scheduler.step()
             if step % 50 == 0:
                 elapsed_time = time.time() - tic
                 print(
-                    f"elapsed_time={elapsed_time:.2f}s (data={data_time:.2f}s) | {step=} | loss={loss:.5f}"
+                    f"elapsed_time={elapsed_time:.2f}s (data={data_time:.2f}s) | {step=} | "
+                    f"loss={loss:.5f} | "
+                    f"alive_ray_mask={alive_ray_mask.long().sum():d} | "
+                    f"counter={counter:d} | compact_counter={compact_counter:d} | num_rays={len(pixels):d} "
                 )
-            if step % 30_000 == 0 and step > 0:
+            if step % 35_000 == 0 and step > 0:
                 # evaluation
                 radiance_field.eval()
                 psnrs = []
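This is the heart of the new training recipe: after every step the ray batch is resized so that the measured sample count drifts toward a fixed budget of 2**16 samples, rounded up to a multiple of 128. The loss is also restricted to rays that actually hit occupied space, and the constant `loss * 128.0` factor reads like a static fp16 loss scale (the renderer's autograd Function uses `custom_bwd`), which Adam's per-parameter normalization makes otherwise harmless. The resize rule, pulled out as a hypothetical standalone helper:

```python
import math

def next_num_rays(cur_num_rays: int, compact_counter: int, target: int = 2**16) -> int:
    """Rescale the ray batch so the next step consumes roughly `target` samples."""
    scaled = int(cur_num_rays * (target / float(max(compact_counter, 1))))  # guard added here
    return int(math.ceil(scaled / 128.0) * 128)  # keep sizes on a 128 granularity
```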
@@ -181,7 +201,7 @@ if __name__ == "__main__":
                    pixels = data["pixels"].to(device)
                    render_bkgd = data["color_bkgd"].to(device)
                    # rendering
-                    rgb, depth, acc, alive_ray_mask = render_image(
+                    rgb, depth, acc, alive_ray_mask, _, _ = render_image(
                        radiance_field, rays, render_bkgd
                    )
                    mse = F.mse_loss(rgb, pixels)
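The elided lines presumably convert this MSE to PSNR before averaging into `psnr_avg`; for images in [0, 1] the standard conversion is:

```python
psnr = -10.0 * torch.log10(mse)  # peak signal-to-noise ratio, peak value 1.0
```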
@@ -191,6 +211,8 @@ if __name__ == "__main__":
                print(f"evaluation: {psnr_avg=}")
            tic_data = time.time()
+            step += 1
 # "train"
 # elapsed_time=298.27s (data=60.08s) | step=30000 | loss=0.00026
 # evaluation: psnr_avg=33.305334663391115 (6.42 it/s)
@@ -210,3 +232,7 @@ if __name__ == "__main__":
 # "trainval" batch_over_images=True, schedule
 # elapsed_time=291.42s (data=52.82s) | step=30000 | loss=0.00020
 # evaluation: psnr_avg=35.41630497932434 (6.40 it/s)
+
+# "trainval" batch_over_images=True, schedule 2**18
+# evaluation: psnr_avg=36.24 (6.75 it/s)
@@ -20,6 +20,7 @@ class VolumeRenderer(torch.autograd.Function):
            accumulated_depth,
            accumulated_color,
            mask,
+            steps_counter,
        ) = volumetric_rendering_forward(packed_info, starts, ends, sigmas, rgbs)
        ctx.save_for_backward(
            accumulated_weight,
@@ -31,11 +32,19 @@ class VolumeRenderer(torch.autograd.Function):
            sigmas,
            rgbs,
        )
-        return accumulated_weight, accumulated_depth, accumulated_color, mask
+        return (
+            accumulated_weight,
+            accumulated_depth,
+            accumulated_color,
+            mask,
+            steps_counter,
+        )

    @staticmethod
    @custom_bwd
-    def backward(ctx, grad_weight, grad_depth, grad_color, _grad_mask):
+    def backward(
+        ctx, grad_weight, grad_depth, grad_color, _grad_mask, _grad_steps_counter
+    ):
        (
            accumulated_weight,
            accumulated_depth,
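`backward` must accept one gradient argument per `forward` output, so the new `steps_counter` output forces the extra `_grad_steps_counter` parameter even though it is never used. A minimal self-contained sketch of the same pattern (toy example, not this commit's code):

```python
import torch

class WithCounter(torch.autograd.Function):
    """Toy Function returning a value plus a non-differentiable counter."""

    @staticmethod
    def forward(ctx, x):
        counter = torch.tensor([x.numel()], dtype=torch.int32)
        ctx.mark_non_differentiable(counter)  # autograd passes None for its grad
        return x * 2.0, counter

    @staticmethod
    def backward(ctx, grad_out, _grad_counter):
        # still one slot per forward output; the counter's slot is ignored
        return grad_out * 2.0

y, n = WithCounter.apply(torch.randn(5, requires_grad=True))
y.sum().backward()
```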
@@ -281,6 +281,6 @@ std::vector<torch::Tensor> ray_marching(
        frustum_ends.data_ptr<float>()
    );
-    return {packed_info, frustum_origins, frustum_dirs, frustum_starts, frustum_ends};
+    return {packed_info, frustum_origins, frustum_dirs, frustum_starts, frustum_ends, steps_counter};
 }
@@ -13,7 +13,9 @@ __global__ void volumetric_rendering_forward_kernel(
    scalar_t* accumulated_weight,  // output
    scalar_t* accumulated_depth,  // output
    scalar_t* accumulated_color,  // output
-    bool* mask  // output
+    bool* mask,  // output
+    // writable helpers
+    int* steps_counter
 ) {
    CUDA_GET_THREAD_ID(thread_id, n_rays);
@@ -54,6 +56,7 @@ __global__ void volumetric_rendering_forward_kernel(
        T *= (1.f - alpha);
    }
    mask[0] = true;
+    atomicAdd(steps_counter, j);
 }
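Every thread folds its ray's actual step count `j` into a single shared cell, so `atomicAdd` is needed to avoid lost updates. Because a ray breaks out of the loop once its transmittance is negligible, this compact count can be smaller than the marched total that the Python wrapper already derives from the packed layout:

```python
# Sketch: relation between the two counters on the Python side.
marched_total = int(packed_info[:, -1].sum())  # samples produced by ray marching
# steps_counter.item() <= marched_total: early-terminated rays skip their tail samples
```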
@@ -178,6 +181,10 @@ std::vector<torch::Tensor> volumetric_rendering_forward(
    const int threads = 256;
    const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
+    // helper counter
+    torch::Tensor steps_counter = torch::zeros(
+        {1}, rgbs.options().dtype(torch::kInt32));
    // outputs
    torch::Tensor accumulated_weight = torch::zeros({n_rays, 1}, sigmas.options());
    torch::Tensor accumulated_depth = torch::zeros({n_rays, 1}, sigmas.options());
@@ -199,11 +206,12 @@ std::vector<torch::Tensor> volumetric_rendering_forward(
            accumulated_weight.data_ptr<scalar_t>(),
            accumulated_depth.data_ptr<scalar_t>(),
            accumulated_color.data_ptr<scalar_t>(),
-            mask.data_ptr<bool>()
+            mask.data_ptr<bool>(),
+            steps_counter.data_ptr<int>()
        );
    }));
-    return {accumulated_weight, accumulated_depth, accumulated_color, mask};
+    return {accumulated_weight, accumulated_depth, accumulated_color, mask, steps_counter};
 }
@@ -34,7 +34,7 @@ def volumetric_rendering(
    if render_est_n_samples is None:
        render_total_samples = n_rays * render_n_samples
    else:
-        render_total_samples = n_rays * render_est_n_samples
+        render_total_samples = render_est_n_samples
    render_step_size = (
        (scene_aabb[3:] - scene_aabb[:3]).max() * math.sqrt(3) / render_n_samples
    )
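Note the semantics change: `render_est_n_samples` is now a total sample budget for the whole batch rather than a per-ray estimate, so the caller can bound worst-case memory directly (`render_image` passes `2**16 * 16` during training). A hedged sketch of the consequence, with a hypothetical allocation shown on the Python side:

```python
# Sketch: marching buffers are sized by render_total_samples, so a fixed
# batch-wide budget caps memory independent of n_rays.
frustum_origins = torch.empty((render_total_samples, 3), device=rays_o.device)
```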
@@ -51,6 +51,7 @@ def volumetric_rendering(
        frustum_dirs,
        frustum_starts,
        frustum_ends,
+        steps_counter,
    ) = ray_marching(
        # rays
        rays_o,
@@ -69,6 +70,7 @@ def volumetric_rendering(
    # squeeze valid samples
    total_samples = max(packed_info[:, -1].sum(), 1)
+    total_samples = int(math.ceil(total_samples / 128.0)) * 128
    frustum_origins = frustum_origins[:total_samples]
    frustum_dirs = frustum_dirs[:total_samples]
    frustum_starts = frustum_starts[:total_samples]
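Rounding the compacted sample count up to a multiple of 128 pads the query batch to a fixed granularity; plausibly this matches tiny-cuda-nn's preferred batch alignment for its fused MLPs (an assumption here, the diff does not say why). As a hypothetical helper:

```python
import math

def round_up(n: int, k: int = 128) -> int:
    # pad to the next multiple of k; e.g. round_up(1000) == 1024
    return int(math.ceil(n / float(k))) * k
```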
@@ -86,6 +88,7 @@ def volumetric_rendering(
        accumulated_depth,
        accumulated_color,
        alive_ray_mask,
+        compact_steps_counter,
    ) = VolumeRenderer.apply(
        packed_info,
        frustum_starts,
@@ -97,4 +100,11 @@ def volumetric_rendering(
    accumulated_depth = torch.clip(accumulated_depth, t_min[:, None], t_max[:, None])
    accumulated_color = accumulated_color + render_bkgd * (1.0 - accumulated_weight)
-    return accumulated_color, accumulated_depth, accumulated_weight, alive_ray_mask
+    return (
+        accumulated_color,
+        accumulated_depth,
+        accumulated_weight,
+        alive_ray_mask,
+        steps_counter,
+        compact_steps_counter,
+    )