Commit 3783437d authored by Jeremy Reizenstein's avatar Jeremy Reizenstein Committed by Facebook GitHub Bot
Browse files

lazy all_train_cameras

Summary: Avoid calculating all_train_cameras before it is needed, because it is slow in some datasets.

Reviewed By: shapovalov

Differential Revision: D38037157

fbshipit-source-id: 95461226655cde2626b680661951ab17ebb0ec75
parent b2dc5202
...@@ -391,7 +391,6 @@ def run_training(cfg: DictConfig) -> None: ...@@ -391,7 +391,6 @@ def run_training(cfg: DictConfig) -> None:
datasource = ImplicitronDataSource(**cfg.data_source_args) datasource = ImplicitronDataSource(**cfg.data_source_args)
datasets, dataloaders = datasource.get_datasets_and_dataloaders() datasets, dataloaders = datasource.get_datasets_and_dataloaders()
task = datasource.get_task() task = datasource.get_task()
all_train_cameras = datasource.get_all_train_cameras()
# init the model # init the model
model, stats, optimizer_state = init_model(cfg=cfg, accelerator=accelerator) model, stats, optimizer_state = init_model(cfg=cfg, accelerator=accelerator)
...@@ -405,7 +404,7 @@ def run_training(cfg: DictConfig) -> None: ...@@ -405,7 +404,7 @@ def run_training(cfg: DictConfig) -> None:
_eval_and_dump( _eval_and_dump(
cfg, cfg,
task, task,
all_train_cameras, datasource.all_train_cameras,
datasets, datasets,
dataloaders, dataloaders,
model, model,
...@@ -490,7 +489,7 @@ def run_training(cfg: DictConfig) -> None: ...@@ -490,7 +489,7 @@ def run_training(cfg: DictConfig) -> None:
): ):
_run_eval( _run_eval(
model, model,
all_train_cameras, datasource.all_train_cameras,
dataloaders.test, dataloaders.test,
task, task,
camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks, camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks,
...@@ -525,7 +524,7 @@ def run_training(cfg: DictConfig) -> None: ...@@ -525,7 +524,7 @@ def run_training(cfg: DictConfig) -> None:
_eval_and_dump( _eval_and_dump(
cfg, cfg,
task, task,
all_train_cameras, datasource.all_train_cameras,
datasets, datasets,
dataloaders, dataloaders,
model, model,
......
...@@ -30,9 +30,10 @@ class DataSourceBase(ReplaceableBase): ...@@ -30,9 +30,10 @@ class DataSourceBase(ReplaceableBase):
def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, DataLoaderMap]: def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, DataLoaderMap]:
raise NotImplementedError() raise NotImplementedError()
def get_all_train_cameras(self) -> Optional[CamerasBase]: @property
def all_train_cameras(self) -> Optional[CamerasBase]:
""" """
If the data is all for a single scene, returns a list If the data is all for a single scene, a list
of the known training cameras for that scene, which is of the known training cameras for that scene, which is
used for evaluating the viewpoint difficulty of the used for evaluating the viewpoint difficulty of the
unseen cameras. unseen cameras.
...@@ -59,6 +60,7 @@ class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13] ...@@ -59,6 +60,7 @@ class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13]
def __post_init__(self): def __post_init__(self):
run_auto_creation(self) run_auto_creation(self)
self._all_train_cameras_cache: Optional[Tuple[Optional[CamerasBase]]] = None
def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, DataLoaderMap]: def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, DataLoaderMap]:
datasets = self.dataset_map_provider.get_dataset_map() datasets = self.dataset_map_provider.get_dataset_map()
...@@ -68,5 +70,10 @@ class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13] ...@@ -68,5 +70,10 @@ class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13]
def get_task(self) -> Task: def get_task(self) -> Task:
return self.dataset_map_provider.get_task() return self.dataset_map_provider.get_task()
def get_all_train_cameras(self) -> Optional[CamerasBase]: @property
return self.dataset_map_provider.get_all_train_cameras() def all_train_cameras(self) -> Optional[CamerasBase]:
if self._all_train_cameras_cache is None: # pyre-ignore[16]
all_train_cameras = self.dataset_map_provider.get_all_train_cameras()
self._all_train_cameras_cache = (all_train_cameras,)
return self._all_train_cameras_cache[0]
...@@ -118,8 +118,6 @@ def evaluate_dbir_for_category( ...@@ -118,8 +118,6 @@ def evaluate_dbir_for_category(
if test_dataset is None or test_dataloader is None: if test_dataset is None or test_dataloader is None:
raise ValueError("must have a test dataset.") raise ValueError("must have a test dataset.")
all_train_cameras = data_source.get_all_train_cameras()
image_size = cast(JsonIndexDataset, test_dataset).image_width image_size = cast(JsonIndexDataset, test_dataset).image_width
if image_size is None: if image_size is None:
...@@ -149,7 +147,7 @@ def evaluate_dbir_for_category( ...@@ -149,7 +147,7 @@ def evaluate_dbir_for_category(
preds["implicitron_render"], preds["implicitron_render"],
bg_color=bg_color, bg_color=bg_color,
lpips_model=lpips_model, lpips_model=lpips_model,
source_cameras=all_train_cameras, source_cameras=data_source.all_train_cameras,
) )
) )
......
...@@ -10,6 +10,7 @@ from collections import defaultdict ...@@ -10,6 +10,7 @@ from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from itertools import product from itertools import product
import numpy as np
from pytorch3d.implicitron.dataset.data_loader_map_provider import ( from pytorch3d.implicitron.dataset.data_loader_map_provider import (
DoublePoolBatchSampler, DoublePoolBatchSampler,
) )
...@@ -53,6 +54,7 @@ class MockDataset(DatasetBase): ...@@ -53,6 +54,7 @@ class MockDataset(DatasetBase):
class TestSceneBatchSampler(unittest.TestCase): class TestSceneBatchSampler(unittest.TestCase):
def setUp(self): def setUp(self):
np.random.seed(42)
self.dataset_overfit = MockDataset(1) self.dataset_overfit = MockDataset(1)
def test_overfit(self): def test_overfit(self):
......
...@@ -31,7 +31,7 @@ class TestDataJsonIndex(TestCaseMixin, unittest.TestCase): ...@@ -31,7 +31,7 @@ class TestDataJsonIndex(TestCaseMixin, unittest.TestCase):
data_source = ImplicitronDataSource(**args) data_source = ImplicitronDataSource(**args)
cameras = data_source.get_all_train_cameras() cameras = data_source.all_train_cameras
self.assertIsInstance(cameras, PerspectiveCameras) self.assertIsInstance(cameras, PerspectiveCameras)
self.assertEqual(len(cameras), 81) self.assertEqual(len(cameras), 81)
......
...@@ -152,6 +152,6 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase): ...@@ -152,6 +152,6 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase):
self.assertEqual(i.frame_type, ["unseen"]) self.assertEqual(i.frame_type, ["unseen"])
self.assertEqual(i.image_rgb.shape, (1, 3, 800, 800)) self.assertEqual(i.image_rgb.shape, (1, 3, 800, 800))
cameras = data_source.get_all_train_cameras() cameras = data_source.all_train_cameras
self.assertIsInstance(cameras, PerspectiveCameras) self.assertIsInstance(cameras, PerspectiveCameras)
self.assertEqual(len(cameras), 100) self.assertEqual(len(cameras), 100)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment