Commit d0e16684 authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

add callbacks for inference_on_dataset

Summary: Pull Request resolved: https://github.com/facebookresearch/d2go/pull/637

Reviewed By: tglik

Differential Revision: D51540498

fbshipit-source-id: f246559963c5187140db7b8113765f66a964ae1b
parent 87649f4f
......@@ -36,6 +36,7 @@ def inference_on_dataset(
model: torch.nn.Module,
data_loader: Iterable,
evaluator: Union[DatasetEvaluator, List[DatasetEvaluator], None],
**kwargs,
):
"""
A drop-in replacement for d2's inference_on_dataset to run inference on datasets,
......@@ -43,7 +44,7 @@ def inference_on_dataset(
* has_finished_process(self) -> bool: return True if `self.process()` could be skipped
"""
if evaluator is None:
return inference_on_dataset_d2(model, data_loader, evaluator)
return inference_on_dataset_d2(model, data_loader, evaluator, **kwargs)
if isinstance(evaluator, abc.MutableSequence):
evaluator = DatasetEvaluators(evaluator)
......@@ -51,7 +52,7 @@ def inference_on_dataset(
if not (
hasattr(evaluator, "has_finished_process") and evaluator.has_finished_process()
):
return inference_on_dataset_d2(model, data_loader, evaluator)
return inference_on_dataset_d2(model, data_loader, evaluator, **kwargs)
evaluator.reset()
results = evaluator.evaluate()
......
......@@ -386,6 +386,16 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner):
)
return evaluator
# experimental API
@classmethod
def _get_inference_callbacks(cls):
return {
"on_start": lambda: None,
"on_end": lambda: None,
"before_inference": lambda: None,
"after_inference": lambda: None,
}
def _do_test(self, cfg, model, train_iter=None, model_tag="default"):
"""train_iter: Current iteration of the model, None means final iteration"""
assert len(cfg.DATASETS.TEST)
......@@ -430,7 +440,10 @@ class Detectron2GoRunner(D2GoDataAPIMixIn, BaseRunner):
else model,
)
results_per_dataset = inference_on_dataset(model, data_loader, evaluator)
inference_callbacks = self._get_inference_callbacks()
results_per_dataset = inference_on_dataset(
model, data_loader, evaluator, callbacks=inference_callbacks
)
if comm.is_main_process():
results[model_tag][dataset_name] = results_per_dataset
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment