"docs/gpu.mdx" did not exist on "021817e59ace5e351b35b2e6881f83a09f038546"
test_evaluator.py 2.73 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#!/usr/bin/env python3

import os
import tempfile
import unittest
from collections import defaultdict

import torch
from d2go.evaluation.evaluator import inference_on_dataset, ResultCache
from detectron2.evaluation import DatasetEvaluator


class EvaluatorForTest(DatasetEvaluator):
    """Minimal evaluator that sums every output passed to ``process``.

    ``evaluate`` returns the sum of all collected outputs. ``reset`` discards
    them.
    """

    def __init__(self):
        # Collected model outputs, one entry per ``process`` call.
        self.results = []

    def reset(self):
        # Empty the existing list in place rather than rebinding the attribute.
        del self.results[:]

    def process(self, inputs, outputs):
        # Only the outputs are recorded; ``inputs`` is ignored.
        self.results += [outputs]

    def evaluate(self):
        total = sum(self.results)
        return total


class EvaluatorWithCheckpointForTest(DatasetEvaluator):
    """Summing evaluator whose results are persisted through ``ResultCache``.

    The first ``evaluate`` call saves the collected outputs to the cache. Any
    later ``evaluate`` call, made once the cache exists, reloads the outputs
    from the cache instead. Each hook invocation is tallied in ``_call_count``
    so tests can check which hooks ran.
    """

    def __init__(self, save_dir):
        self.results = []
        # Cache rooted at ``save_dir``; its existence marks processing as done.
        self.result_cache = ResultCache(save_dir)
        # Maps a hook name to the number of times it has been called.
        self._call_count = defaultdict(int)

    def reset(self):
        del self.results[:]
        self._call_count["reset"] += 1

    def has_finished_process(self):
        # Processing is considered finished once the cache has been written.
        cached = self.result_cache.has_cache()
        return cached

    def process(self, inputs, outputs):
        # Once the cache exists, nothing should be fed through ``process``.
        assert not self.result_cache.has_cache()
        self._call_count["process"] += 1
        self.results += [outputs]

    def evaluate(self):
        cache = self.result_cache
        if cache.has_cache():
            # Reuse the outputs persisted by an earlier run.
            self.results = cache.load()
        else:
            # First evaluation: persist what ``process`` collected.
            cache.save(self.results)
        self._call_count["evaluate"] += 1

        total = sum(self.results)
        return total


class Model(torch.nn.Module):
    """Identity module: ``forward`` returns its input unchanged."""

    def forward(self, x):
        output = x
        return output


class TestEvaluator(unittest.TestCase):
    """Tests for ``inference_on_dataset`` with plain and cache-backed evaluators."""

    def _assert_call_counts(self, evaluator, reset, process, evaluate):
        """Assert how many times each evaluator hook has been invoked."""
        counts = evaluator._call_count
        self.assertEqual(counts["reset"], reset)
        self.assertEqual(counts["process"], process)
        self.assertEqual(counts["evaluate"], evaluate)

    def test_inference(self):
        model = Model()
        evaluator = EvaluatorForTest()
        # The identity model over 1..5 means the summing evaluator should yield 15.
        self.assertEqual(
            inference_on_dataset(model, list(range(1, 6)), evaluator), 15
        )

    def test_inference_with_checkpoint(self):
        with tempfile.TemporaryDirectory() as save_dir:
            model = Model()
            evaluator = EvaluatorWithCheckpointForTest(save_dir)
            self.assertFalse(evaluator.has_finished_process())
            data_loader = list(range(1, 6))

            # First pass: no cache yet, so every sample goes through ``process``.
            self.assertEqual(inference_on_dataset(model, data_loader, evaluator), 15)
            self._assert_call_counts(evaluator, reset=1, process=5, evaluate=1)

            # Second pass: the cache exists, so the process count must stay
            # at 5 while reset/evaluate run again and results are reloaded.
            self.assertTrue(evaluator.has_finished_process())
            self.assertEqual(inference_on_dataset(model, data_loader, evaluator), 15)
            self._assert_call_counts(evaluator, reset=2, process=5, evaluate=2)
            self.assertTrue(os.path.isfile(evaluator.result_cache.cache_file))