Commit efb72132 authored by Jeremy Reizenstein, committed by Facebook GitHub Bot
Browse files

extract camera_difficulty_bin_breaks

Summary: As part of removing Task, move the camera difficulty bin breaks from hard-coded values to a top-level config option.

Reviewed By: davnov134

Differential Revision: D37491040

fbshipit-source-id: f2d6775ebc490f6f75020d13f37f6b588cc07a0b
parent 40fb189c
...@@ -27,3 +27,6 @@ solver_args: ...@@ -27,3 +27,6 @@ solver_args:
max_epochs: 3000 max_epochs: 3000
milestones: milestones:
- 1000 - 1000
camera_difficulty_bin_breaks:
- 0.666667
- 0.833334
...@@ -535,7 +535,12 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None: ...@@ -535,7 +535,12 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
and epoch % cfg.test_interval == 0 and epoch % cfg.test_interval == 0
): ):
_run_eval( _run_eval(
model, all_train_cameras, dataloaders.test, task, device=device model,
all_train_cameras,
dataloaders.test,
task,
camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks,
device=device,
) )
assert stats.epoch == epoch, "inconsistent stats!" assert stats.epoch == epoch, "inconsistent stats!"
...@@ -588,7 +593,14 @@ def _eval_and_dump( ...@@ -588,7 +593,14 @@ def _eval_and_dump(
if dataloader is None: if dataloader is None:
raise ValueError('DataLoaderMap have to contain the "test" entry for eval!') raise ValueError('DataLoaderMap have to contain the "test" entry for eval!')
results = _run_eval(model, all_train_cameras, dataloader, task, device=device) results = _run_eval(
model,
all_train_cameras,
dataloader,
task,
camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks,
device=device,
)
# add the evaluation epoch to the results # add the evaluation epoch to the results
for r in results: for r in results:
...@@ -615,7 +627,14 @@ def _get_eval_frame_data(frame_data): ...@@ -615,7 +627,14 @@ def _get_eval_frame_data(frame_data):
return frame_data_for_eval return frame_data_for_eval
def _run_eval(model, all_train_cameras, loader, task: Task, device): def _run_eval(
model,
all_train_cameras,
loader,
task: Task,
camera_difficulty_bin_breaks: Tuple[float, float],
device,
):
""" """
Run the evaluation loop on the test dataloader Run the evaluation loop on the test dataloader
""" """
...@@ -648,7 +667,7 @@ def _run_eval(model, all_train_cameras, loader, task: Task, device): ...@@ -648,7 +667,7 @@ def _run_eval(model, all_train_cameras, loader, task: Task, device):
) )
_, category_result = evaluate.summarize_nvs_eval_results( _, category_result = evaluate.summarize_nvs_eval_results(
per_batch_eval_results, task per_batch_eval_results, task, camera_difficulty_bin_breaks
) )
return category_result["results"] return category_result["results"]
...@@ -684,6 +703,7 @@ class ExperimentConfig(Configurable): ...@@ -684,6 +703,7 @@ class ExperimentConfig(Configurable):
visdom_server: str = "http://127.0.0.1" visdom_server: str = "http://127.0.0.1"
visualize_interval: int = 1000 visualize_interval: int = 1000
clip_grad: float = 0.0 clip_grad: float = 0.0
camera_difficulty_bin_breaks: Tuple[float, ...] = 0.97, 0.98
hydra: dict = field( hydra: dict = field(
default_factory=lambda: { default_factory=lambda: {
......
...@@ -375,6 +375,9 @@ visdom_port: 8097 ...@@ -375,6 +375,9 @@ visdom_port: 8097
visdom_server: http://127.0.0.1 visdom_server: http://127.0.0.1
visualize_interval: 1000 visualize_interval: 1000
clip_grad: 0.0 clip_grad: 0.0
camera_difficulty_bin_breaks:
- 0.97
- 0.98
hydra: hydra:
run: run:
dir: . dir: .
......
...@@ -153,8 +153,13 @@ def evaluate_dbir_for_category( ...@@ -153,8 +153,13 @@ def evaluate_dbir_for_category(
) )
) )
if task == Task.SINGLE_SEQUENCE:
camera_difficulty_bin_breaks = 0.97, 0.98
else:
camera_difficulty_bin_breaks = 2.0 / 3, 5.0 / 6
category_result_flat, category_result = summarize_nvs_eval_results( category_result_flat, category_result = summarize_nvs_eval_results(
per_batch_eval_results, task per_batch_eval_results, task, camera_difficulty_bin_breaks
) )
return category_result["results"] return category_result["results"]
......
...@@ -9,7 +9,7 @@ import copy ...@@ -9,7 +9,7 @@ import copy
import warnings import warnings
from collections import OrderedDict from collections import OrderedDict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Sequence, Union from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np import numpy as np
import torch import torch
...@@ -407,19 +407,13 @@ def _reduce_camera_iou_overlap(ious: torch.Tensor, topk: int = 2) -> torch.Tenso ...@@ -407,19 +407,13 @@ def _reduce_camera_iou_overlap(ious: torch.Tensor, topk: int = 2) -> torch.Tenso
return ious.topk(k=min(topk, len(ious) - 1)).values.mean() return ious.topk(k=min(topk, len(ious) - 1)).values.mean()
def _get_camera_difficulty_bin_edges(task: Task): def _get_camera_difficulty_bin_edges(camera_difficulty_bin_breaks: Tuple[float, float]):
""" """
Get the edges of camera difficulty bins. Get the edges of camera difficulty bins.
""" """
_eps = 1e-5 _eps = 1e-5
if task == Task.MULTI_SEQUENCE: lower, upper = camera_difficulty_bin_breaks
# TODO: extract those to constants diff_bin_edges = torch.tensor([0.0 - _eps, lower, upper, 1.0 + _eps]).float()
diff_bin_edges = torch.linspace(0.5, 1.0 + _eps, 4)
diff_bin_edges[0] = 0.0 - _eps
elif task == Task.SINGLE_SEQUENCE:
diff_bin_edges = torch.tensor([0.0 - _eps, 0.97, 0.98, 1.0 + _eps]).float()
else:
raise ValueError(f"No such eval task {task}.")
diff_bin_names = ["hard", "medium", "easy"] diff_bin_names = ["hard", "medium", "easy"]
return diff_bin_edges, diff_bin_names return diff_bin_edges, diff_bin_names
...@@ -427,6 +421,7 @@ def _get_camera_difficulty_bin_edges(task: Task): ...@@ -427,6 +421,7 @@ def _get_camera_difficulty_bin_edges(task: Task):
def summarize_nvs_eval_results( def summarize_nvs_eval_results(
per_batch_eval_results: List[Dict[str, Any]], per_batch_eval_results: List[Dict[str, Any]],
task: Task, task: Task,
camera_difficulty_bin_breaks: Tuple[float, float] = (0.97, 0.98),
): ):
""" """
Compile the per-batch evaluation results `per_batch_eval_results` into Compile the per-batch evaluation results `per_batch_eval_results` into
...@@ -435,6 +430,8 @@ def summarize_nvs_eval_results( ...@@ -435,6 +430,8 @@ def summarize_nvs_eval_results(
Args: Args:
per_batch_eval_results: Metrics of each per-batch evaluation. per_batch_eval_results: Metrics of each per-batch evaluation.
task: The type of the new-view synthesis task. task: The type of the new-view synthesis task.
camera_difficulty_bin_breaks: The edges between the hard/medium and the medium/easy camera difficulty bins.
Returns: Returns:
nvs_results_flat: A flattened dict of all aggregate metrics. nvs_results_flat: A flattened dict of all aggregate metrics.
...@@ -461,7 +458,9 @@ def summarize_nvs_eval_results( ...@@ -461,7 +458,9 @@ def summarize_nvs_eval_results(
# init the result database dict # init the result database dict
results = [] results = []
diff_bin_edges, diff_bin_names = _get_camera_difficulty_bin_edges(task) diff_bin_edges, diff_bin_names = _get_camera_difficulty_bin_edges(
camera_difficulty_bin_breaks
)
n_diff_edges = diff_bin_edges.numel() n_diff_edges = diff_bin_edges.numel()
# add per set averages # add per set averages
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment