test_modeling_meta_arch_modeling_hook.py 4.15 KB
Newer Older
Peizhao Zhang's avatar
Peizhao Zhang committed
1
2
3
4
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved


5
import copy
Peizhao Zhang's avatar
Peizhao Zhang committed
6
7
import unittest

8
import d2go.runner.default_runner as default_runner
Peizhao Zhang's avatar
Peizhao Zhang committed
9
import torch
10
11
from d2go.config import CfgNode
from d2go.modeling import build_model
Peizhao Zhang's avatar
Peizhao Zhang committed
12
from d2go.modeling.meta_arch import modeling_hook as mh
13
from d2go.registry.builtin import META_ARCH_REGISTRY
Peizhao Zhang's avatar
Peizhao Zhang committed
14
15


16
@META_ARCH_REGISTRY.register()
class TestArch(torch.nn.Module):
    """Trivial meta-architecture used by these tests: it doubles its input.

    It is registered in ``META_ARCH_REGISTRY`` so a config can select it with
    ``MODEL.META_ARCHITECTURE = "TestArch"``.
    """

    def __init__(self, cfg):
        # ``cfg`` is ignored; presumably it is accepted only so the
        # meta-arch builder can construct this class from a config.
        super().__init__()

    def forward(self, x):
        doubled = x * 2
        return doubled


# create a wrapper of the model that adds 1 to the output
class PlusOneWrapper(torch.nn.Module):
    """Wrap ``model`` so that ``forward(x)`` returns ``model(x) + 1``.

    The wrapped module is kept as ``self.model`` so it can be unwrapped again.
    """

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x):
        inner_out = self.model(x)
        return inner_out + 1


35
@mh.MODELING_HOOK_REGISTRY.register()
class PlusOneHook(mh.ModelingHook):
    """Modeling hook that wraps a model in ``PlusOneWrapper`` and can undo it."""

    def __init__(self, cfg):
        super().__init__(cfg)

    def apply(self, model: torch.nn.Module) -> torch.nn.Module:
        """Return ``model`` wrapped so that its output is incremented by 1."""
        wrapped = PlusOneWrapper(model)
        return wrapped

    def unapply(self, model: torch.nn.Module) -> torch.nn.Module:
        """Return the module that ``apply`` wrapped.

        Raises ``AssertionError`` if ``model`` is not a ``PlusOneWrapper``.
        """
        assert isinstance(model, PlusOneWrapper)
        inner = model.model
        return inner


# create a wrapper of the model that multiplies the output by 2
class TimesTwoWrapper(torch.nn.Module):
    """Wrap ``model`` so that ``forward(x)`` returns ``model(x) * 2``.

    The wrapped module is stored as ``self.model`` so it can be unwrapped again.
    """

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x):
        return self.model(x) * 2


@mh.MODELING_HOOK_REGISTRY.register()
class TimesTwoHook(mh.ModelingHook):
    """Modeling hook that wraps a model in ``TimesTwoWrapper`` and can undo it."""

    def __init__(self, cfg):
        super().__init__(cfg)

    def apply(self, model: torch.nn.Module) -> torch.nn.Module:
        """Return ``model`` wrapped so that its output is multiplied by 2."""
        wrapped = TimesTwoWrapper(model)
        return wrapped

    def unapply(self, model: torch.nn.Module) -> torch.nn.Module:
        """Return the module that ``apply`` wrapped.

        Raises ``AssertionError`` if ``model`` is not a ``TimesTwoWrapper``.
        """
        assert isinstance(model, TimesTwoWrapper)
        inner = model.model
        return inner


class TestModelingHook(unittest.TestCase):
    """Tests for modeling hooks applied directly, via build_model, and via the runner."""

    @staticmethod
    def _set_hooked_test_arch(cfg):
        """Configure ``cfg`` to build ``TestArch`` on CPU with PlusOne then TimesTwo hooks.

        With these hooks applied in order, ``model(2) == ((2 * 2) + 1) * 2 == 10``.
        Returns ``cfg`` for convenience.
        """
        cfg.MODEL.DEVICE = "cpu"
        cfg.MODEL.META_ARCHITECTURE = "TestArch"
        cfg.MODEL.MODELING_HOOKS = ["PlusOneHook", "TimesTwoHook"]
        return cfg

    def _make_cfg(self):
        """Return a minimal standalone ``CfgNode`` configured for the hooked TestArch."""
        cfg = CfgNode()
        cfg.MODEL = CfgNode()
        return self._set_hooked_test_arch(cfg)

    def _assert_unapplies_to_test_arch(self, model):
        """Check that ``model`` carries hook metadata and unwraps to a plain TestArch.

        Returns the unwrapped model.
        """
        self.assertTrue(hasattr(model, "_modeling_hooks"))
        self.assertTrue(hasattr(model, "unapply_modeling_hooks"))
        orig_model = model.unapply_modeling_hooks()
        self.assertIsInstance(orig_model, TestArch)
        # Without hooks, TestArch only doubles its input.
        self.assertEqual(orig_model(2), 4)
        return orig_model

    def test_modeling_hook_simple(self):
        """Apply and unapply a single hook directly, without a config."""
        model = TestArch(None)
        hook = PlusOneHook(None)
        model_with_hook = hook.apply(model)
        self.assertEqual(model_with_hook(2), 5)
        original_model = hook.unapply(model_with_hook)
        # unapply must hand back the very same module object, not a copy.
        self.assertIs(original_model, model)

    def test_modeling_hook_cfg(self):
        """Create model with modeling hook using build_model"""
        model = build_model(self._make_cfg())
        self.assertEqual(model(2), 10)
        self._assert_unapplies_to_test_arch(model)

    def test_modeling_hook_runner(self):
        """Create model with modeling hook from runner"""
        runner = default_runner.Detectron2GoRunner()
        # Registered as a cleanup (instead of a trailing call) so it also runs
        # when an assertion below fails.
        self.addCleanup(default_runner._close_all_tbx_writers)
        cfg = self._set_hooked_test_arch(runner.get_default_cfg())
        model = runner.build_model(cfg)
        self.assertEqual(model(2), 10)
        self._assert_unapplies_to_test_arch(model)

    def test_modeling_hook_copy(self):
        """Create model with modeling hook, the model could be copied"""
        model = build_model(self._make_cfg())
        self.assertEqual(model(2), 10)

        model_copy = copy.deepcopy(model)
        # The copy must behave like the original while hooks are still applied.
        self.assertEqual(model_copy(2), 10)

        self._assert_unapplies_to_test_arch(model)
        # Unapplying hooks on the original must leave the copy unwrappable.
        self._assert_unapplies_to_test_arch(model_copy)