"lmdeploy/vscode:/vscode.git/clone" did not exist on "8ba2d7c51b22a7c1ee9d641dc1d57a8398d32c67"
test_meta_arch_semantic_seg.py 3.76 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved


import os
import shutil
import tempfile
import unittest

import torch
from d2go.export.api import convert_and_export_predictor
from d2go.export.d2_meta_arch import patch_d2_meta_arch
from d2go.runner import Detectron2GoRunner
from mobile_cv.predictor.api import create_predictor

# Add APIs to D2's meta arch. This is usually done in the runner's setup, but in
# unit tests it has to be called separately (maybe it should be applied by default).
patch_d2_meta_arch()


def _get_batch(height, width, is_train):
    """Build a list of synthetic input dicts for a semantic segmentation model.

    Every sample carries a random ``(3, height, width)`` float32 ``"image"``
    plus fixed ``"filename"``, ``"width"`` and ``"height"`` entries. When
    ``is_train`` is truthy, each sample also carries a random 0/1 int64
    ``"sem_seg"`` mask of shape ``(height, width)``, and the batch holds two
    samples instead of one.

    Note: the ``"width"``/``"height"`` entries are always 100, independent of
    the actual image size passed in.
    """
    num_samples = 1 if not is_train else 2
    batch = []
    for _ in range(num_samples):
        sample = {"filename": "some_file", "width": 100, "height": 100}
        # The image is drawn before the mask, keeping the RNG consumption order fixed.
        sample["image"] = torch.rand(3, height, width).to(torch.float32)
        if is_train:
            sample["sem_seg"] = torch.randint(
                low=0, high=2, size=(height, width)
            ).to(torch.int64)
        batch.append(sample)
    return batch


def _get_data_loader(height, width, is_train):
    """Return an endless generator that yields one synthetic batch repeatedly.

    The batch is built once, eagerly, via ``_get_batch(height, width, is_train)``.
    Every ``next()`` on the returned generator yields that same list object.
    """

    def _repeat_forever(batch):
        while True:
            yield batch

    return _repeat_forever(_get_batch(height, width, is_train))


def _get_input_dim(model):
    """Return a square ``(height, width)`` test input size for ``model``.

    The side length is ``model.backbone.size_divisibility``, clamped to at
    least 1 so a non-positive value (e.g. 0) still gives a usable size.
    """
    side = model.backbone.size_divisibility
    if side < 1:
        side = 1
    return side, side


class BaseSemanticSegTestCase:
    """Namespace that holds the shared semantic-segmentation test template.

    ``TemplateTestCase`` is nested inside a plain (non-``TestCase``) class so
    that unittest's module-level test discovery does not collect and run the
    abstract template itself. Concrete tests subclass
    ``BaseSemanticSegTestCase.TemplateTestCase`` and implement
    ``setup_custom_test``.
    """

    class TemplateTestCase(unittest.TestCase):
        """Inference, training and export smoke tests for a semantic seg model.

        Subclasses must implement ``setup_custom_test`` to customize
        ``self.cfg`` (e.g. merge a model config file) before the model is built.
        """

        def setUp(self):
            # Per-test scratch directory, removed automatically after the test.
            self.test_dir = tempfile.mkdtemp(prefix="test_meta_arch_semantic_seg_")
            self.addCleanup(shutil.rmtree, self.test_dir)

            runner = Detectron2GoRunner()
            self.cfg = runner.get_default_cfg()
            # Subclass hook: customize self.cfg before the model is built.
            self.setup_custom_test()

            # Applied after the subclass hook, so the model always runs on CPU
            # regardless of what the custom config specifies.
            self.cfg.merge_from_list(["MODEL.DEVICE", "cpu"])
            self.test_model = runner.build_model(self.cfg, eval_only=True)

        def setup_custom_test(self):
            """Customize ``self.cfg``; concrete test cases must override this."""
            raise NotImplementedError()

        def test_inference(self):
            """Run a gradient-free forward pass on a single unlabeled sample."""
            h, w = _get_input_dim(self.test_model)
            inputs = _get_batch(h, w, False)
            with torch.no_grad():
                self.test_model(inputs)

        def test_train(self):
            """Run one forward/backward pass in train mode on a labeled batch."""
            h, w = _get_input_dim(self.test_model)
            inputs = _get_batch(h, w, True)
            self.test_model.train()
            loss_dict = self.test_model(inputs)
            losses = sum(loss_dict.values())
            losses.backward()

        def _test_export(self, predictor_type, compare_match=True):
            """Export the model, reload it as a predictor and check its outputs.

            Args:
                predictor_type: predictor type string forwarded to
                    ``convert_and_export_predictor`` (e.g. ``"torchscript"``).
                compare_match: if True, also require the exported predictor's
                    ``"sem_seg"`` outputs to be close to the eager model's.
            """
            h, w = _get_input_dim(self.test_model)
            dl = _get_data_loader(h, w, False)
            # dl yields the same batch forever, so ``inputs`` is identical to any
            # batch that convert_and_export_predictor draws from it.
            inputs = next(iter(dl))

            output_dir = os.path.join(self.test_dir, "test_export")
            predictor_path = convert_and_export_predictor(
                self.cfg, self.test_model, predictor_type, output_dir, dl
            )

            predictor = create_predictor(predictor_path)
            predictor_outputs = predictor(inputs)
            self.assertEqual(len(predictor_outputs), len(inputs))

            with torch.no_grad():
                pytorch_outputs = self.test_model(inputs)
                self.assertEqual(len(pytorch_outputs), len(inputs))

            if compare_match:
                for predictor_output, pytorch_output in zip(
                    predictor_outputs, pytorch_outputs
                ):
                    # NOTE(review): torch.testing.assert_allclose is deprecated in
                    # newer PyTorch in favor of torch.testing.assert_close (which
                    # also checks dtype); kept for compatibility with older torch.
                    torch.testing.assert_allclose(
                        predictor_output["sem_seg"], pytorch_output["sem_seg"]
                    )


class TestR50FPN(BaseSemanticSegTestCase.TemplateTestCase):
    """Semantic segmentation tests using the R-50 FPN 1x config."""

    def setup_custom_test(self):
        """Merge the ``detectron2://`` R-50 FPN semantic seg config into ``self.cfg``."""
        config_path = "detectron2://Misc/semantic_R_50_FPN_1x.yaml"
        self.cfg.merge_from_file(config_path)

    def test_export_torchscript(self):
        """Export to TorchScript and require outputs to match the eager model."""
        self._test_export(predictor_type="torchscript", compare_match=True)