"vscode:/vscode.git/clone" did not exist on "4b8eaf205b0a73cd040e7917c6b602e456feda4e"
Commit 923a0568 authored by Peizhao Zhang's avatar Peizhao Zhang Committed by Facebook GitHub Bot
Browse files

switch to use inference_on_dataset_with_checkpointing in default runner.

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/355

switch to use inference_on_dataset_with_checkpointing in default runner.

Reviewed By: HarounH

Differential Revision: D37215292

fbshipit-source-id: c006784ce0b31700bcbb1f79c303fd791f1561ff
parent baded432
...@@ -18,6 +18,20 @@ from detectron2.utils.file_io import PathManager ...@@ -18,6 +18,20 @@ from detectron2.utils.file_io import PathManager
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def DatasetEvaluators_has_finished_process(self):
ret = True
for x in self._evaluators:
if hasattr(x, "has_finished_process"):
ret &= x.has_finished_process()
else:
ret &= False
return ret
# patch evaluators defined in d2
DatasetEvaluators.has_finished_process = DatasetEvaluators_has_finished_process
def inference_on_dataset( def inference_on_dataset(
model: torch.nn.Module, model: torch.nn.Module,
data_loader: Iterable, data_loader: Iterable,
......
...@@ -22,6 +22,7 @@ from d2go.data.utils import ( ...@@ -22,6 +22,7 @@ from d2go.data.utils import (
maybe_subsample_n_images, maybe_subsample_n_images,
update_cfg_if_using_adhoc_dataset, update_cfg_if_using_adhoc_dataset,
) )
from d2go.evaluation.evaluator import inference_on_dataset
from d2go.modeling import kmeans_anchors, model_ema from d2go.modeling import kmeans_anchors, model_ema
from d2go.modeling.api import build_d2go_model from d2go.modeling.api import build_d2go_model
from d2go.modeling.model_freezing_utils import freeze_matched_bn, set_requires_grad from d2go.modeling.model_freezing_utils import freeze_matched_bn, set_requires_grad
...@@ -47,7 +48,6 @@ from detectron2.engine import AMPTrainer, hooks, SimpleTrainer ...@@ -47,7 +48,6 @@ from detectron2.engine import AMPTrainer, hooks, SimpleTrainer
from detectron2.evaluation import ( from detectron2.evaluation import (
COCOEvaluator, COCOEvaluator,
DatasetEvaluators, DatasetEvaluators,
inference_on_dataset,
LVISEvaluator, LVISEvaluator,
print_csv_format, print_csv_format,
RotatedCOCOEvaluator, RotatedCOCOEvaluator,
......
...@@ -276,3 +276,6 @@ class VisualizationEvaluator(DatasetEvaluator): ...@@ -276,3 +276,6 @@ class VisualizationEvaluator(DatasetEvaluator):
self._log_remaining -= 1 self._log_remaining -= 1
self._iter += 1 self._iter += 1
def has_finished_process(self):
return True
...@@ -7,7 +7,7 @@ from collections import defaultdict ...@@ -7,7 +7,7 @@ from collections import defaultdict
import torch import torch
from d2go.evaluation.evaluator import inference_on_dataset, ResultCache from d2go.evaluation.evaluator import inference_on_dataset, ResultCache
from detectron2.evaluation import DatasetEvaluator from detectron2.evaluation import DatasetEvaluator, DatasetEvaluators
class EvaluatorForTest(DatasetEvaluator): class EvaluatorForTest(DatasetEvaluator):
...@@ -85,3 +85,15 @@ class TestEvaluator(unittest.TestCase): ...@@ -85,3 +85,15 @@ class TestEvaluator(unittest.TestCase):
self.assertEqual(evaluator._call_count["process"], 5) self.assertEqual(evaluator._call_count["process"], 5)
self.assertEqual(evaluator._call_count["evaluate"], 2) self.assertEqual(evaluator._call_count["evaluate"], 2)
self.assertTrue(os.path.isfile(evaluator.result_cache.cache_file)) self.assertTrue(os.path.isfile(evaluator.result_cache.cache_file))
def test_evaluators_patch(self):
with tempfile.TemporaryDirectory() as save_dir:
cp_evaluator = EvaluatorWithCheckpointForTest(save_dir)
evaluator = DatasetEvaluators([cp_evaluator])
self.assertFalse(evaluator.has_finished_process())
cp_evaluator.reset()
cp_evaluator.process(1, 1)
cp_evaluator.evaluate()
self.assertTrue(evaluator.has_finished_process())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment