Commit d0e16684 authored by Yanghan Wang, committed by Facebook GitHub Bot

add callbacks for inference_on_dataset

Summary: Pull Request resolved: https://github.com/facebookresearch/d2go/pull/637

Reviewed By: tglik

Differential Revision: D51540498

fbshipit-source-id: f246559963c5187140db7b8113765f66a964ae1b
parent 87649f4f
@@ -36,6 +36,7 @@ def inference_on_dataset(
     model: torch.nn.Module,
     data_loader: Iterable,
     evaluator: Union[DatasetEvaluator, List[DatasetEvaluator], None],
+    **kwargs,
 ):
     """
     A drop-in replacement for d2's inference_on_dataset to run inference on datasets,
@@ -43,7 +44,7 @@ def inference_on_dataset(
     * has_finished_process(self) -> bool: return True if `self.process()` could be skipped
     """
     if evaluator is None:
-        return inference_on_dataset_d2(model, data_loader, evaluator)
+        return inference_on_dataset_d2(model, data_loader, evaluator, **kwargs)
     if isinstance(evaluator, abc.MutableSequence):
         evaluator = DatasetEvaluators(evaluator)
@@ -51,7 +52,7 @@ def inference_on_dataset(
     if not (
         hasattr(evaluator, "has_finished_process") and evaluator.has_finished_process()
     ):
-        return inference_on_dataset_d2(model, data_loader, evaluator)
+        return inference_on_dataset_d2(model, data_loader, evaluator, **kwargs)
     evaluator.reset()
     results = evaluator.evaluate()
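For context, the forwarded `**kwargs` is how a `callbacks` dict can reach the underlying detectron2 `inference_on_dataset`. Below is a minimal usage sketch, not part of the diff: it assumes the wrapper is importable from `d2go.evaluation.evaluator`, that `model`, `data_loader`, and `evaluator` are already built, and that the d2 function accepts a `callbacks` dict keyed by the four hook names the runner uses.

```python
# Minimal usage sketch (not part of the diff). Assumes the wrapper lives in
# d2go.evaluation.evaluator and that the underlying detectron2
# inference_on_dataset accepts a `callbacks` dict with these four keys.
from d2go.evaluation.evaluator import inference_on_dataset

callbacks = {
    "on_start": lambda: print("starting inference"),  # before the eval loop
    "on_end": lambda: print("finished inference"),    # after the eval loop
    "before_inference": lambda: None,                 # before each forward pass
    "after_inference": lambda: None,                  # after each forward pass
}

# model, data_loader, and evaluator are assumed to be built elsewhere.
results = inference_on_dataset(model, data_loader, evaluator, callbacks=callbacks)
```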
@@ -386,6 +386,16 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner):
         )
         return evaluator

+    # experimental API
+    @classmethod
+    def _get_inference_callbacks(cls):
+        return {
+            "on_start": lambda: None,
+            "on_end": lambda: None,
+            "before_inference": lambda: None,
+            "after_inference": lambda: None,
+        }
+
     def _do_test(self, cfg, model, train_iter=None, model_tag="default"):
         """train_iter: Current iteration of the model, None means final iteration"""
         assert len(cfg.DATASETS.TEST)
@@ -430,7 +440,10 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner):
                 else model,
             )
-            results_per_dataset = inference_on_dataset(model, data_loader, evaluator)
+            inference_callbacks = self._get_inference_callbacks()
+            results_per_dataset = inference_on_dataset(
+                model, data_loader, evaluator, callbacks=inference_callbacks
+            )
             if comm.is_main_process():
                 results[model_tag][dataset_name] = results_per_dataset
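Since `_get_inference_callbacks` is an experimental classmethod, a runner subclass can override it to inject behavior around evaluation. The sketch below is illustrative only: `MyRunner` and the log messages are hypothetical, and the import path for `Detectron2GoRunner` is assumed to be `d2go.runner.default_runner`.

```python
# Illustrative subclass sketch (not part of the diff): override the
# experimental hook so every call to _do_test passes these callbacks
# into inference_on_dataset. MyRunner and the log messages are hypothetical.
import logging

from d2go.runner.default_runner import Detectron2GoRunner

logger = logging.getLogger(__name__)


class MyRunner(Detectron2GoRunner):
    @classmethod
    def _get_inference_callbacks(cls):
        return {
            "on_start": lambda: logger.info("evaluation started"),
            "on_end": lambda: logger.info("evaluation finished"),
            "before_inference": lambda: None,
            "after_inference": lambda: None,
        }
```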