Commit baded432 authored by Peizhao Zhang, committed by Facebook GitHub Bot

allow skipping inference when running evaluation.

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/354

Allow skipping inference when running evaluation.
* The new `inference_on_dataset` works similarly to `inference_on_dataset` in d2 but allows skipping the inference step if the evaluator has cached the results.
* If the evaluator has a `has_finished_process` method that returns True, inference is skipped and only `evaluator.reset()` and `evaluator.evaluate()` are called (see the sketch below).
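For illustration, a minimal sketch of how an evaluator can opt into skipping: `CachingEvaluator`, the toy data loader, and `/tmp/eval_cache` are made up for this example, while `inference_on_dataset`, `ResultCache`, and `DatasetEvaluator` are the ones used in the diff below.

# Illustrative sketch only -- not part of this commit. CachingEvaluator, the toy
# data loader, and "/tmp/eval_cache" are hypothetical; inference_on_dataset,
# ResultCache, and DatasetEvaluator are the ones added/used in the diff below.
import torch

from d2go.evaluation.evaluator import inference_on_dataset, ResultCache
from detectron2.evaluation import DatasetEvaluator


class CachingEvaluator(DatasetEvaluator):
    def __init__(self, cache_dir):
        self.results = []
        self.result_cache = ResultCache(cache_dir)

    def has_finished_process(self):
        # Once results are cached, inference_on_dataset skips the model forward
        # pass and the process() calls entirely.
        return self.result_cache.has_cache()

    def reset(self):
        self.results = []

    def process(self, inputs, outputs):
        self.results.append(outputs)

    def evaluate(self):
        if self.result_cache.has_cache():
            self.results = self.result_cache.load()
        else:
            self.result_cache.save(self.results)
        return sum(self.results)


evaluator = CachingEvaluator("/tmp/eval_cache")
model = torch.nn.Identity()
data_loader = [1, 2, 3]
inference_on_dataset(model, data_loader, evaluator)  # runs inference, fills the cache
inference_on_dataset(model, data_loader, evaluator)  # skips inference, reuses the cache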

Reviewed By: wat3rBro

Differential Revision: D37213004

fbshipit-source-id: d12cc480589ff04fd8dbb42b22633ab34bc4bf63
parent 6ca4702b
# Copyright (c) Facebook, Inc. and its affiliates.

import logging
import os
from collections import abc
from typing import Any, Iterable, List, Union

import torch
from detectron2.evaluation import (
    DatasetEvaluator,
    DatasetEvaluators,
    inference_on_dataset as inference_on_dataset_d2,
)
from detectron2.utils import comm
from detectron2.utils.file_io import PathManager

logger = logging.getLogger(__name__)
def inference_on_dataset(
    model: torch.nn.Module,
    data_loader: Iterable,
    evaluator: Union[DatasetEvaluator, List[DatasetEvaluator], None],
):
    """
    A drop-in replacement for d2's inference_on_dataset that runs inference on a dataset
    and supports checkpointing via an optional evaluator hook:
    * has_finished_process(self) -> bool: return True if `self.process()` can be skipped,
      in which case only `evaluator.reset()` and `evaluator.evaluate()` are called.
    """
    if evaluator is None:
        return inference_on_dataset_d2(model, data_loader, evaluator)

    if isinstance(evaluator, abc.MutableSequence):
        evaluator = DatasetEvaluators(evaluator)

    if not (
        hasattr(evaluator, "has_finished_process") and evaluator.has_finished_process()
    ):
        return inference_on_dataset_d2(model, data_loader, evaluator)

    evaluator.reset()
    results = evaluator.evaluate()
    if results is None:
        results = {}
    return results


class ResultCache(object):
    def __init__(self, cache_dir: str):
        """A utility class to handle save/load cache data across processes"""
        self.cache_str = cache_dir

    @property
    def cache_file(self):
        if self.cache_str is None:
            return None
        return os.path.join(self.cache_str, f"_result_cache_.{comm.get_rank()}.pkl")

    def has_cache(self):
        return PathManager.isfile(self.cache_file)
    def load(self, gather: bool = False):
        """
        Load cached results.
        gather (bool): gather cached results across ranks into a list
        """
        if self.cache_file is None or not PathManager.exists(self.cache_file):
            return None

        with PathManager.open(self.cache_file, "rb") as fp:
            ret = torch.load(fp)
        logger.info(f"Loaded from checkpoint {self.cache_file}")
        if gather:
            ret = comm.all_gather(ret)
        return ret
    def save(self, data: Any):
        if self.cache_file is None:
            return

        PathManager.mkdirs(os.path.dirname(self.cache_file))
        with PathManager.open(self.cache_file, "wb") as fp:
            torch.save(data, fp)
        logger.info(f"Saved checkpoint to {self.cache_file}")
#!/usr/bin/env python3

import os
import tempfile
import unittest
from collections import defaultdict

import torch
from d2go.evaluation.evaluator import inference_on_dataset, ResultCache
from detectron2.evaluation import DatasetEvaluator


class EvaluatorForTest(DatasetEvaluator):
    def __init__(self):
        self.results = []

    def reset(self):
        self.results.clear()

    def process(self, inputs, outputs):
        self.results.append(outputs)

    def evaluate(self):
        return sum(self.results)


class EvaluatorWithCheckpointForTest(DatasetEvaluator):
    def __init__(self, save_dir):
        self.results = []
        self.result_cache = ResultCache(save_dir)
        self._call_count = defaultdict(int)

    def reset(self):
        self.results.clear()
        self._call_count["reset"] += 1

    def has_finished_process(self):
        return self.result_cache.has_cache()

    def process(self, inputs, outputs):
        assert not self.result_cache.has_cache()
        self.results.append(outputs)
        self._call_count["process"] += 1

    def evaluate(self):
        if not self.result_cache.has_cache():
            self.result_cache.save(self.results)
        else:
            self.results = self.result_cache.load()
        self._call_count["evaluate"] += 1
        return sum(self.results)


class Model(torch.nn.Module):
    def forward(self, x):
        return x


class TestEvaluator(unittest.TestCase):
    def test_inference(self):
        model = Model()
        evaluator = EvaluatorForTest()
        data_loader = [1, 2, 3, 4, 5]
        results = inference_on_dataset(model, data_loader, evaluator)
        self.assertEqual(results, 15)

    def test_inference_with_checkpoint(self):
        with tempfile.TemporaryDirectory() as save_dir:
            model = Model()
            evaluator = EvaluatorWithCheckpointForTest(save_dir)
            self.assertFalse(evaluator.has_finished_process())
            data_loader = [1, 2, 3, 4, 5]
            results = inference_on_dataset(model, data_loader, evaluator)
            self.assertEqual(results, 15)
            self.assertEqual(evaluator._call_count["reset"], 1)
            self.assertEqual(evaluator._call_count["process"], 5)
            self.assertEqual(evaluator._call_count["evaluate"], 1)

            # run again with cache
            self.assertTrue(evaluator.has_finished_process())
            results = inference_on_dataset(model, data_loader, evaluator)
            self.assertEqual(results, 15)
            self.assertEqual(evaluator._call_count["reset"], 2)
            self.assertEqual(evaluator._call_count["process"], 5)
            self.assertEqual(evaluator._call_count["evaluate"], 2)
            self.assertTrue(os.path.isfile(evaluator.result_cache.cache_file))