Commit efb72132 authored by Jeremy Reizenstein's avatar Jeremy Reizenstein Committed by Facebook GitHub Bot
Browse files

extract camera_difficulty_bin_breaks

Summary: As part of removing Task, move the camera difficulty bin breaks from hard-coded values to the top-level configuration.

Reviewed By: davnov134

Differential Revision: D37491040

fbshipit-source-id: f2d6775ebc490f6f75020d13f37f6b588cc07a0b
parent 40fb189c
......@@ -27,3 +27,6 @@ solver_args:
max_epochs: 3000
milestones:
- 1000
camera_difficulty_bin_breaks:
- 0.666667
- 0.833334
......@@ -535,7 +535,12 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
and epoch % cfg.test_interval == 0
):
_run_eval(
model, all_train_cameras, dataloaders.test, task, device=device
model,
all_train_cameras,
dataloaders.test,
task,
camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks,
device=device,
)
assert stats.epoch == epoch, "inconsistent stats!"
......@@ -588,7 +593,14 @@ def _eval_and_dump(
if dataloader is None:
raise ValueError('DataLoaderMap have to contain the "test" entry for eval!')
results = _run_eval(model, all_train_cameras, dataloader, task, device=device)
results = _run_eval(
model,
all_train_cameras,
dataloader,
task,
camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks,
device=device,
)
# add the evaluation epoch to the results
for r in results:
......@@ -615,7 +627,14 @@ def _get_eval_frame_data(frame_data):
return frame_data_for_eval
def _run_eval(model, all_train_cameras, loader, task: Task, device):
def _run_eval(
model,
all_train_cameras,
loader,
task: Task,
camera_difficulty_bin_breaks: Tuple[float, float],
device,
):
"""
Run the evaluation loop on the test dataloader
"""
......@@ -648,7 +667,7 @@ def _run_eval(model, all_train_cameras, loader, task: Task, device):
)
_, category_result = evaluate.summarize_nvs_eval_results(
per_batch_eval_results, task
per_batch_eval_results, task, camera_difficulty_bin_breaks
)
return category_result["results"]
......@@ -684,6 +703,7 @@ class ExperimentConfig(Configurable):
visdom_server: str = "http://127.0.0.1"
visualize_interval: int = 1000
clip_grad: float = 0.0
camera_difficulty_bin_breaks: Tuple[float, ...] = 0.97, 0.98
hydra: dict = field(
default_factory=lambda: {
......
......@@ -375,6 +375,9 @@ visdom_port: 8097
visdom_server: http://127.0.0.1
visualize_interval: 1000
clip_grad: 0.0
camera_difficulty_bin_breaks:
- 0.97
- 0.98
hydra:
run:
dir: .
......
......@@ -153,8 +153,13 @@ def evaluate_dbir_for_category(
)
)
if task == Task.SINGLE_SEQUENCE:
camera_difficulty_bin_breaks = 0.97, 0.98
else:
camera_difficulty_bin_breaks = 2.0 / 3, 5.0 / 6
category_result_flat, category_result = summarize_nvs_eval_results(
per_batch_eval_results, task
per_batch_eval_results, task, camera_difficulty_bin_breaks
)
return category_result["results"]
......
......@@ -9,7 +9,7 @@ import copy
import warnings
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Sequence, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
......@@ -407,19 +407,13 @@ def _reduce_camera_iou_overlap(ious: torch.Tensor, topk: int = 2) -> torch.Tenso
return ious.topk(k=min(topk, len(ious) - 1)).values.mean()
def _get_camera_difficulty_bin_edges(task: Task):
def _get_camera_difficulty_bin_edges(camera_difficulty_bin_breaks: Tuple[float, float]):
"""
Get the edges of camera difficulty bins.
"""
_eps = 1e-5
if task == Task.MULTI_SEQUENCE:
# TODO: extract those to constants
diff_bin_edges = torch.linspace(0.5, 1.0 + _eps, 4)
diff_bin_edges[0] = 0.0 - _eps
elif task == Task.SINGLE_SEQUENCE:
diff_bin_edges = torch.tensor([0.0 - _eps, 0.97, 0.98, 1.0 + _eps]).float()
else:
raise ValueError(f"No such eval task {task}.")
lower, upper = camera_difficulty_bin_breaks
diff_bin_edges = torch.tensor([0.0 - _eps, lower, upper, 1.0 + _eps]).float()
diff_bin_names = ["hard", "medium", "easy"]
return diff_bin_edges, diff_bin_names
......@@ -427,6 +421,7 @@ def _get_camera_difficulty_bin_edges(task: Task):
def summarize_nvs_eval_results(
per_batch_eval_results: List[Dict[str, Any]],
task: Task,
camera_difficulty_bin_breaks: Tuple[float, float] = (0.97, 0.98),
):
"""
Compile the per-batch evaluation results `per_batch_eval_results` into
......@@ -435,6 +430,8 @@ def summarize_nvs_eval_results(
Args:
per_batch_eval_results: Metrics of each per-batch evaluation.
task: The type of the new-view synthesis task.
camera_difficulty_bin_breaks: The two breakpoints separating the hard-medium and medium-easy camera difficulty bins.
Returns:
nvs_results_flat: A flattened dict of all aggregate metrics.
......@@ -461,7 +458,9 @@ def summarize_nvs_eval_results(
# init the result database dict
results = []
diff_bin_edges, diff_bin_names = _get_camera_difficulty_bin_edges(task)
diff_bin_edges, diff_bin_names = _get_camera_difficulty_bin_edges(
camera_difficulty_bin_breaks
)
n_diff_edges = diff_bin_edges.numel()
# add per set averages
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment