Commit f8bf5280 authored by Jeremy Reizenstein's avatar Jeremy Reizenstein Committed by Facebook GitHub Bot
Browse files

remove get_task

Summary: Remove the dataset's need to provide the task type.

Reviewed By: davnov134, kjchalup

Differential Revision: D38314000

fbshipit-source-id: 3805d885b5d4528abdc78c0da03247edb9abf3f7
parent 37250a43
...@@ -35,3 +35,4 @@ training_loop_ImplicitronTrainingLoop_args: ...@@ -35,3 +35,4 @@ training_loop_ImplicitronTrainingLoop_args:
camera_difficulty_bin_breaks: camera_difficulty_bin_breaks:
- 0.666667 - 0.666667
- 0.833334 - 0.833334
is_multisequence: true
...@@ -206,7 +206,6 @@ class Experiment(Configurable): # pyre-ignore: 13 ...@@ -206,7 +206,6 @@ class Experiment(Configurable): # pyre-ignore: 13
val_loader, val_loader,
) = accelerator.prepare(model, optimizer, train_loader, val_loader) ) = accelerator.prepare(model, optimizer, train_loader, val_loader)
task = self.data_source.get_task()
all_train_cameras = self.data_source.all_train_cameras all_train_cameras = self.data_source.all_train_cameras
# Enter the main training loop. # Enter the main training loop.
...@@ -223,7 +222,6 @@ class Experiment(Configurable): # pyre-ignore: 13 ...@@ -223,7 +222,6 @@ class Experiment(Configurable): # pyre-ignore: 13
exp_dir=self.exp_dir, exp_dir=self.exp_dir,
stats=stats, stats=stats,
seed=self.seed, seed=self.seed,
task=task,
) )
def _check_config_consistent(self) -> None: def _check_config_consistent(self) -> None:
......
...@@ -10,7 +10,6 @@ from typing import Any, Optional ...@@ -10,7 +10,6 @@ from typing import Any, Optional
import torch import torch
from accelerate import Accelerator from accelerate import Accelerator
from pytorch3d.implicitron.dataset.data_source import Task
from pytorch3d.implicitron.evaluation.evaluator import EvaluatorBase from pytorch3d.implicitron.evaluation.evaluator import EvaluatorBase
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
from pytorch3d.implicitron.models.generic_model import EvaluationMode from pytorch3d.implicitron.models.generic_model import EvaluationMode
...@@ -101,7 +100,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13] ...@@ -101,7 +100,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
exp_dir: str, exp_dir: str,
stats: Stats, stats: Stats,
seed: int, seed: int,
task: Task,
**kwargs, **kwargs,
): ):
""" """
...@@ -123,7 +121,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13] ...@@ -123,7 +121,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
epoch=stats.epoch, epoch=stats.epoch,
exp_dir=exp_dir, exp_dir=exp_dir,
model=model, model=model,
task=task,
) )
return return
else: else:
...@@ -179,7 +176,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13] ...@@ -179,7 +176,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
device=device, device=device,
dataloader=test_loader, dataloader=test_loader,
model=model, model=model,
task=task,
) )
assert stats.epoch == epoch, "inconsistent stats!" assert stats.epoch == epoch, "inconsistent stats!"
...@@ -200,7 +196,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13] ...@@ -200,7 +196,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
exp_dir=exp_dir, exp_dir=exp_dir,
dataloader=test_loader, dataloader=test_loader,
model=model, model=model,
task=task,
) )
else: else:
raise ValueError( raise ValueError(
......
...@@ -435,3 +435,4 @@ training_loop_ImplicitronTrainingLoop_args: ...@@ -435,3 +435,4 @@ training_loop_ImplicitronTrainingLoop_args:
camera_difficulty_bin_breaks: camera_difficulty_bin_breaks:
- 0.97 - 0.97
- 0.98 - 0.98
is_multisequence: false
...@@ -15,7 +15,7 @@ from pytorch3d.renderer.cameras import CamerasBase ...@@ -15,7 +15,7 @@ from pytorch3d.renderer.cameras import CamerasBase
from .blender_dataset_map_provider import BlenderDatasetMapProvider # noqa from .blender_dataset_map_provider import BlenderDatasetMapProvider # noqa
from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, Task from .dataset_map_provider import DatasetMap, DatasetMapProviderBase
from .json_index_dataset_map_provider import JsonIndexDatasetMapProvider # noqa from .json_index_dataset_map_provider import JsonIndexDatasetMapProvider # noqa
from .json_index_dataset_map_provider_v2 import JsonIndexDatasetMapProviderV2 # noqa from .json_index_dataset_map_provider_v2 import JsonIndexDatasetMapProviderV2 # noqa
from .llff_dataset_map_provider import LlffDatasetMapProvider # noqa from .llff_dataset_map_provider import LlffDatasetMapProvider # noqa
...@@ -41,9 +41,6 @@ class DataSourceBase(ReplaceableBase): ...@@ -41,9 +41,6 @@ class DataSourceBase(ReplaceableBase):
""" """
raise NotImplementedError() raise NotImplementedError()
def get_task(self) -> Task:
raise NotImplementedError()
@registry.register @registry.register
class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13] class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13]
...@@ -71,9 +68,6 @@ class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13] ...@@ -71,9 +68,6 @@ class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13]
dataloaders = self.data_loader_map_provider.get_data_loader_map(datasets) dataloaders = self.data_loader_map_provider.get_data_loader_map(datasets)
return datasets, dataloaders return datasets, dataloaders
def get_task(self) -> Task:
return self.dataset_map_provider.get_task()
@property @property
def all_train_cameras(self) -> Optional[CamerasBase]: def all_train_cameras(self) -> Optional[CamerasBase]:
if self._all_train_cameras_cache is None: # pyre-ignore[16] if self._all_train_cameras_cache is None: # pyre-ignore[16]
......
...@@ -7,7 +7,6 @@ ...@@ -7,7 +7,6 @@
import logging import logging
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum
from typing import Iterator, Optional from typing import Iterator, Optional
from iopath.common.file_io import PathManager from iopath.common.file_io import PathManager
...@@ -53,11 +52,6 @@ class DatasetMap: ...@@ -53,11 +52,6 @@ class DatasetMap:
yield self.test yield self.test
class Task(Enum):
SINGLE_SEQUENCE = "singlesequence"
MULTI_SEQUENCE = "multisequence"
class DatasetMapProviderBase(ReplaceableBase): class DatasetMapProviderBase(ReplaceableBase):
""" """
Base class for a provider of training / validation and testing Base class for a provider of training / validation and testing
...@@ -71,9 +65,6 @@ class DatasetMapProviderBase(ReplaceableBase): ...@@ -71,9 +65,6 @@ class DatasetMapProviderBase(ReplaceableBase):
""" """
raise NotImplementedError() raise NotImplementedError()
def get_task(self) -> Task:
raise NotImplementedError()
def get_all_train_cameras(self) -> Optional[CamerasBase]: def get_all_train_cameras(self) -> Optional[CamerasBase]:
""" """
If the data is all for a single scene, returns a list If the data is all for a single scene, returns a list
......
...@@ -17,12 +17,7 @@ from pytorch3d.implicitron.tools.config import ( ...@@ -17,12 +17,7 @@ from pytorch3d.implicitron.tools.config import (
) )
from pytorch3d.renderer.cameras import CamerasBase from pytorch3d.renderer.cameras import CamerasBase
from .dataset_map_provider import ( from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, PathManagerFactory
DatasetMap,
DatasetMapProviderBase,
PathManagerFactory,
Task,
)
from .json_index_dataset import JsonIndexDataset from .json_index_dataset import JsonIndexDataset
from .utils import ( from .utils import (
...@@ -160,7 +155,7 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13] ...@@ -160,7 +155,7 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
# This maps the common names of the dataset subsets ("train"/"val"/"test") # This maps the common names of the dataset subsets ("train"/"val"/"test")
# to the names of the subsets in the CO3D dataset. # to the names of the subsets in the CO3D dataset.
set_names_mapping = _get_co3d_set_names_mapping( set_names_mapping = _get_co3d_set_names_mapping(
self.get_task(), self.task_str,
self.test_on_train, self.test_on_train,
self.only_test_set, self.only_test_set,
) )
...@@ -185,7 +180,7 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13] ...@@ -185,7 +180,7 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
eval_batch_index = json.load(f) eval_batch_index = json.load(f)
restrict_sequence_name = self.restrict_sequence_name restrict_sequence_name = self.restrict_sequence_name
if self.get_task() == Task.SINGLE_SEQUENCE: if self.task_str == "singlesequence":
if ( if (
self.test_restrict_sequence_id is None self.test_restrict_sequence_id is None
or self.test_restrict_sequence_id < 0 or self.test_restrict_sequence_id < 0
...@@ -267,13 +262,12 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13] ...@@ -267,13 +262,12 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
# pyre-ignore[16] # pyre-ignore[16]
return self.dataset_map return self.dataset_map
def get_task(self) -> Task:
return Task(self.task_str)
def get_all_train_cameras(self) -> Optional[CamerasBase]: def get_all_train_cameras(self) -> Optional[CamerasBase]:
if Task(self.task_str) == Task.MULTI_SEQUENCE: if self.task_str == "multisequence":
return None return None
assert self.task_str == "singlesequence"
# pyre-ignore[16] # pyre-ignore[16]
train_dataset = self.dataset_map.train train_dataset = self.dataset_map.train
assert isinstance(train_dataset, JsonIndexDataset) assert isinstance(train_dataset, JsonIndexDataset)
...@@ -281,7 +275,7 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13] ...@@ -281,7 +275,7 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
def _get_co3d_set_names_mapping( def _get_co3d_set_names_mapping(
task: Task, task_str: str,
test_on_train: bool, test_on_train: bool,
only_test: bool, only_test: bool,
) -> Dict[str, List[str]]: ) -> Dict[str, List[str]]:
...@@ -295,7 +289,7 @@ def _get_co3d_set_names_mapping( ...@@ -295,7 +289,7 @@ def _get_co3d_set_names_mapping(
- val (if not test_on_train) - val (if not test_on_train)
- test (if not test_on_train) - test (if not test_on_train)
""" """
single_seq = task == Task.SINGLE_SEQUENCE single_seq = task_str == "singlesequence"
if only_test: if only_test:
set_names_mapping = {} set_names_mapping = {}
......
...@@ -16,7 +16,6 @@ from pytorch3d.implicitron.dataset.dataset_map_provider import ( ...@@ -16,7 +16,6 @@ from pytorch3d.implicitron.dataset.dataset_map_provider import (
DatasetMap, DatasetMap,
DatasetMapProviderBase, DatasetMapProviderBase,
PathManagerFactory, PathManagerFactory,
Task,
) )
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
from pytorch3d.implicitron.tools.config import ( from pytorch3d.implicitron.tools.config import (
...@@ -335,12 +334,6 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13] ...@@ -335,12 +334,6 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
) )
return category_to_subset_name_list return category_to_subset_name_list
def get_task(self) -> Task: # TODO: we plan to get rid of tasks
return {
"manyview": Task.SINGLE_SEQUENCE,
"fewview": Task.MULTI_SEQUENCE,
}[self.subset_name.split("_")[0]]
def get_all_train_cameras(self) -> Optional[CamerasBase]: def get_all_train_cameras(self) -> Optional[CamerasBase]:
# pyre-ignore[16] # pyre-ignore[16]
train_dataset = self.dataset_map.train train_dataset = self.dataset_map.train
......
...@@ -28,12 +28,7 @@ from pytorch3d.renderer import ( ...@@ -28,12 +28,7 @@ from pytorch3d.renderer import (
) )
from pytorch3d.structures.meshes import Meshes from pytorch3d.structures.meshes import Meshes
from .dataset_map_provider import ( from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, PathManagerFactory
DatasetMap,
DatasetMapProviderBase,
PathManagerFactory,
Task,
)
from .single_sequence_dataset import SingleSceneDataset from .single_sequence_dataset import SingleSceneDataset
from .utils import DATASET_TYPE_KNOWN from .utils import DATASET_TYPE_KNOWN
...@@ -83,9 +78,6 @@ class RenderedMeshDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13 ...@@ -83,9 +78,6 @@ class RenderedMeshDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13
# pyre-ignore[16] # pyre-ignore[16]
return DatasetMap(train=self.train_dataset, val=None, test=None) return DatasetMap(train=self.train_dataset, val=None, test=None)
def get_task(self) -> Task:
return Task.SINGLE_SEQUENCE
def get_all_train_cameras(self) -> CamerasBase: def get_all_train_cameras(self) -> CamerasBase:
# pyre-ignore[16] # pyre-ignore[16]
return self.poses return self.poses
......
...@@ -21,12 +21,7 @@ from pytorch3d.implicitron.tools.config import ( ...@@ -21,12 +21,7 @@ from pytorch3d.implicitron.tools.config import (
from pytorch3d.renderer import CamerasBase, join_cameras_as_batch, PerspectiveCameras from pytorch3d.renderer import CamerasBase, join_cameras_as_batch, PerspectiveCameras
from .dataset_base import DatasetBase, FrameData from .dataset_base import DatasetBase, FrameData
from .dataset_map_provider import ( from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, PathManagerFactory
DatasetMap,
DatasetMapProviderBase,
PathManagerFactory,
Task,
)
from .utils import DATASET_TYPE_KNOWN, DATASET_TYPE_UNKNOWN from .utils import DATASET_TYPE_KNOWN, DATASET_TYPE_UNKNOWN
_SINGLE_SEQUENCE_NAME: str = "one_sequence" _SINGLE_SEQUENCE_NAME: str = "one_sequence"
...@@ -159,9 +154,6 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase): ...@@ -159,9 +154,6 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
test=self._get_dataset(2, DATASET_TYPE_UNKNOWN, True), test=self._get_dataset(2, DATASET_TYPE_UNKNOWN, True),
) )
def get_task(self) -> Task:
return Task.SINGLE_SEQUENCE
def get_all_train_cameras(self) -> Optional[CamerasBase]: def get_all_train_cameras(self) -> Optional[CamerasBase]:
# pyre-ignore[16] # pyre-ignore[16]
cameras = [self.poses[i] for i in self.i_split[0]] cameras = [self.poses[i] for i in self.i_split[0]]
......
...@@ -7,11 +7,12 @@ ...@@ -7,11 +7,12 @@
import dataclasses import dataclasses
import os import os
from enum import Enum
from typing import Any, cast, Dict, List, Optional, Tuple from typing import Any, cast, Dict, List, Optional, Tuple
import lpips import lpips
import torch import torch
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider import ( from pytorch3d.implicitron.dataset.json_index_dataset_map_provider import (
CO3D_CATEGORIES, CO3D_CATEGORIES,
...@@ -27,6 +28,11 @@ from pytorch3d.implicitron.tools.utils import dataclass_to_cuda_ ...@@ -27,6 +28,11 @@ from pytorch3d.implicitron.tools.utils import dataclass_to_cuda_
from tqdm import tqdm from tqdm import tqdm
class Task(Enum):
SINGLE_SEQUENCE = "singlesequence"
MULTI_SEQUENCE = "multisequence"
def main() -> None: def main() -> None:
""" """
Evaluates new view synthesis metrics of a simple depth-based image rendering Evaluates new view synthesis metrics of a simple depth-based image rendering
...@@ -153,11 +159,15 @@ def evaluate_dbir_for_category( ...@@ -153,11 +159,15 @@ def evaluate_dbir_for_category(
if task == Task.SINGLE_SEQUENCE: if task == Task.SINGLE_SEQUENCE:
camera_difficulty_bin_breaks = 0.97, 0.98 camera_difficulty_bin_breaks = 0.97, 0.98
multisequence_evaluation = False
else: else:
camera_difficulty_bin_breaks = 2.0 / 3, 5.0 / 6 camera_difficulty_bin_breaks = 2.0 / 3, 5.0 / 6
multisequence_evaluation = True
category_result_flat, category_result = summarize_nvs_eval_results( category_result_flat, category_result = summarize_nvs_eval_results(
per_batch_eval_results, task, camera_difficulty_bin_breaks per_batch_eval_results,
camera_difficulty_bin_breaks=camera_difficulty_bin_breaks,
is_multisequence=multisequence_evaluation,
) )
return category_result["results"] return category_result["results"]
......
...@@ -14,7 +14,6 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union ...@@ -14,7 +14,6 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from pytorch3d.implicitron.dataset.data_source import Task
from pytorch3d.implicitron.dataset.dataset_base import FrameData from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.utils import is_known_frame, is_train_frame from pytorch3d.implicitron.dataset.utils import is_known_frame, is_train_frame
from pytorch3d.implicitron.models.base_model import ImplicitronRender from pytorch3d.implicitron.models.base_model import ImplicitronRender
...@@ -420,16 +419,16 @@ def _get_camera_difficulty_bin_edges(camera_difficulty_bin_breaks: Tuple[float, ...@@ -420,16 +419,16 @@ def _get_camera_difficulty_bin_edges(camera_difficulty_bin_breaks: Tuple[float,
def summarize_nvs_eval_results( def summarize_nvs_eval_results(
per_batch_eval_results: List[Dict[str, Any]], per_batch_eval_results: List[Dict[str, Any]],
task: Task, is_multisequence: bool,
camera_difficulty_bin_breaks: Tuple[float, float] = (0.97, 0.98), camera_difficulty_bin_breaks: Tuple[float, float],
): ):
""" """
Compile the per-batch evaluation results `per_batch_eval_results` into Compile the per-batch evaluation results `per_batch_eval_results` into
a set of aggregate metrics. The produced metrics depend on the task. a set of aggregate metrics. The produced metrics depend on is_multisequence.
Args: Args:
per_batch_eval_results: Metrics of each per-batch evaluation. per_batch_eval_results: Metrics of each per-batch evaluation.
task: The type of the new-view synthesis task. is_multisequence: Whether to evaluate as a multisequence task
camera_difficulty_bin_breaks: edge hard-medium and medium-easy camera_difficulty_bin_breaks: edge hard-medium and medium-easy
...@@ -439,14 +438,9 @@ def summarize_nvs_eval_results( ...@@ -439,14 +438,9 @@ def summarize_nvs_eval_results(
""" """
n_batches = len(per_batch_eval_results) n_batches = len(per_batch_eval_results)
eval_sets: List[Optional[str]] = [] eval_sets: List[Optional[str]] = []
if task == Task.SINGLE_SEQUENCE: eval_sets = [None]
eval_sets = [None] if is_multisequence:
# assert n_batches==100
elif task == Task.MULTI_SEQUENCE:
eval_sets = ["train", "test"] eval_sets = ["train", "test"]
# assert n_batches==1000
else:
raise ValueError(task)
batch_sizes = torch.tensor( batch_sizes = torch.tensor(
[r["meta"]["batch_size"] for r in per_batch_eval_results] [r["meta"]["batch_size"] for r in per_batch_eval_results]
).long() ).long()
...@@ -466,11 +460,9 @@ def summarize_nvs_eval_results( ...@@ -466,11 +460,9 @@ def summarize_nvs_eval_results(
# add per set averages # add per set averages
for SET in eval_sets: for SET in eval_sets:
if SET is None: if SET is None:
assert task == Task.SINGLE_SEQUENCE
ok_set = torch.ones(n_batches, dtype=torch.bool) ok_set = torch.ones(n_batches, dtype=torch.bool)
set_name = "test" set_name = "test"
else: else:
assert task == Task.MULTI_SEQUENCE
ok_set = is_train == int(SET == "train") ok_set = is_train == int(SET == "train")
set_name = SET set_name = SET
...@@ -495,7 +487,7 @@ def summarize_nvs_eval_results( ...@@ -495,7 +487,7 @@ def summarize_nvs_eval_results(
} }
) )
if task == Task.MULTI_SEQUENCE: if is_multisequence:
# split based on n_src_views # split based on n_src_views
n_src_views = batch_sizes - 1 n_src_views = batch_sizes - 1
for n_src in EVAL_N_SRC_VIEWS: for n_src in EVAL_N_SRC_VIEWS:
......
...@@ -16,7 +16,6 @@ import torch ...@@ -16,7 +16,6 @@ import torch
import tqdm import tqdm
from pytorch3d.implicitron.dataset import utils as ds_utils from pytorch3d.implicitron.dataset import utils as ds_utils
from pytorch3d.implicitron.dataset.data_source import Task
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
from pytorch3d.implicitron.models.base_model import EvaluationMode, ImplicitronModelBase from pytorch3d.implicitron.models.base_model import EvaluationMode, ImplicitronModelBase
...@@ -57,6 +56,7 @@ class ImplicitronEvaluator(EvaluatorBase): ...@@ -57,6 +56,7 @@ class ImplicitronEvaluator(EvaluatorBase):
""" """
camera_difficulty_bin_breaks: Tuple[float, ...] = 0.97, 0.98 camera_difficulty_bin_breaks: Tuple[float, ...] = 0.97, 0.98
is_multisequence: bool = False
def __post_init__(self): def __post_init__(self):
run_auto_creation(self) run_auto_creation(self)
...@@ -65,7 +65,6 @@ class ImplicitronEvaluator(EvaluatorBase): ...@@ -65,7 +65,6 @@ class ImplicitronEvaluator(EvaluatorBase):
self, self,
model: ImplicitronModelBase, model: ImplicitronModelBase,
dataloader: DataLoader, dataloader: DataLoader,
task: Task,
all_train_cameras: Optional[CamerasBase], all_train_cameras: Optional[CamerasBase],
device: torch.device, device: torch.device,
dump_to_json: bool = False, dump_to_json: bool = False,
...@@ -80,7 +79,6 @@ class ImplicitronEvaluator(EvaluatorBase): ...@@ -80,7 +79,6 @@ class ImplicitronEvaluator(EvaluatorBase):
Args: Args:
model: A (trained) model to evaluate. model: A (trained) model to evaluate.
dataloader: A test dataloader. dataloader: A test dataloader.
task: Type of the novel-view synthesis task we're working on.
all_train_cameras: Camera instances we used for training. all_train_cameras: Camera instances we used for training.
device: A torch device. device: A torch device.
dump_to_json: If True, will dump the results to a json file. dump_to_json: If True, will dump the results to a json file.
...@@ -122,7 +120,9 @@ class ImplicitronEvaluator(EvaluatorBase): ...@@ -122,7 +120,9 @@ class ImplicitronEvaluator(EvaluatorBase):
) )
_, category_result = evaluate.summarize_nvs_eval_results( _, category_result = evaluate.summarize_nvs_eval_results(
per_batch_eval_results, task, self.camera_difficulty_bin_breaks per_batch_eval_results,
self.is_multisequence,
self.camera_difficulty_bin_breaks,
) )
results = category_result["results"] results = category_result["results"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment