Commit 923a0568 authored by Peizhao Zhang's avatar Peizhao Zhang Committed by Facebook GitHub Bot
Browse files

switch to use inference_on_dataset_with_checkpointing in default runner.

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/355

switch to use inference_on_dataset_with_checkpointing in default runner.

Reviewed By: HarounH

Differential Revision: D37215292

fbshipit-source-id: c006784ce0b31700bcbb1f79c303fd791f1561ff
parent baded432
......@@ -18,6 +18,20 @@ from detectron2.utils.file_io import PathManager
logger = logging.getLogger(__name__)
def DatasetEvaluators_has_finished_process(self):
ret = True
for x in self._evaluators:
if hasattr(x, "has_finished_process"):
ret &= x.has_finished_process()
else:
ret &= False
return ret
# patch evaluators defined in d2
DatasetEvaluators.has_finished_process = DatasetEvaluators_has_finished_process
def inference_on_dataset(
model: torch.nn.Module,
data_loader: Iterable,
......
......@@ -22,6 +22,7 @@ from d2go.data.utils import (
maybe_subsample_n_images,
update_cfg_if_using_adhoc_dataset,
)
from d2go.evaluation.evaluator import inference_on_dataset
from d2go.modeling import kmeans_anchors, model_ema
from d2go.modeling.api import build_d2go_model
from d2go.modeling.model_freezing_utils import freeze_matched_bn, set_requires_grad
......@@ -47,7 +48,6 @@ from detectron2.engine import AMPTrainer, hooks, SimpleTrainer
from detectron2.evaluation import (
COCOEvaluator,
DatasetEvaluators,
inference_on_dataset,
LVISEvaluator,
print_csv_format,
RotatedCOCOEvaluator,
......
......@@ -276,3 +276,6 @@ class VisualizationEvaluator(DatasetEvaluator):
self._log_remaining -= 1
self._iter += 1
def has_finished_process(self):
return True
......@@ -7,7 +7,7 @@ from collections import defaultdict
import torch
from d2go.evaluation.evaluator import inference_on_dataset, ResultCache
from detectron2.evaluation import DatasetEvaluator
from detectron2.evaluation import DatasetEvaluator, DatasetEvaluators
class EvaluatorForTest(DatasetEvaluator):
......@@ -85,3 +85,15 @@ class TestEvaluator(unittest.TestCase):
self.assertEqual(evaluator._call_count["process"], 5)
self.assertEqual(evaluator._call_count["evaluate"], 2)
self.assertTrue(os.path.isfile(evaluator.result_cache.cache_file))
def test_evaluators_patch(self):
with tempfile.TemporaryDirectory() as save_dir:
cp_evaluator = EvaluatorWithCheckpointForTest(save_dir)
evaluator = DatasetEvaluators([cp_evaluator])
self.assertFalse(evaluator.has_finished_process())
cp_evaluator.reset()
cp_evaluator.process(1, 1)
cp_evaluator.evaluate()
self.assertTrue(evaluator.has_finished_process())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment