"include/git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "1bd15b9447e0925e6704cfce6caa3ba4fa9b7758"
test_modeling_meta_arch_modeling_hook.py 4.62 KB
Newer Older
Peizhao Zhang's avatar
Peizhao Zhang committed
1
2
3
4
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved


5
import copy
Peizhao Zhang's avatar
Peizhao Zhang committed
6
import unittest
7
from typing import List
Peizhao Zhang's avatar
Peizhao Zhang committed
8

9
import d2go.runner.default_runner as default_runner
Peizhao Zhang's avatar
Peizhao Zhang committed
10
import torch
11
from d2go.config import CfgNode
Yanghan Wang's avatar
Yanghan Wang committed
12
from d2go.modeling import modeling_hook as mh
13
from d2go.modeling.api import build_d2go_model, D2GoModelBuildResult
Mircea Cimpoi's avatar
Mircea Cimpoi committed
14
from d2go.registry.builtin import META_ARCH_REGISTRY, MODELING_HOOK_REGISTRY
Peizhao Zhang's avatar
Peizhao Zhang committed
15
16


17
@META_ARCH_REGISTRY.register()
class TestArch(torch.nn.Module):
    """Minimal meta-architecture for these tests: ``forward(x)`` returns ``x * 2``.

    Registered in ``META_ARCH_REGISTRY`` so it can be selected by name through
    ``cfg.MODEL.META_ARCHITECTURE = "TestArch"``.
    """

    def __init__(self, cfg):
        # ``cfg`` is accepted but not used; the tests also construct this
        # class directly as ``TestArch(None)``.
        super().__init__()

    def forward(self, x):
        # A simple doubling makes the effect of the wrapping hooks easy to
        # check by hand (e.g. 2 -> 4 with no hooks applied).
        return x * 2


# Create a wrapper around the model that adds 1 to its output.
class PlusOneWrapper(torch.nn.Module):
    """Wrap ``model`` so that ``forward(x)`` returns ``model(x) + 1``.

    The wrapped module is kept as ``self.model`` so that ``PlusOneHook.unapply``
    can return it unchanged.
    """

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x):
        return self.model(x) + 1


Mircea Cimpoi's avatar
Mircea Cimpoi committed
36
@MODELING_HOOK_REGISTRY.register()
class PlusOneHook(mh.ModelingHook):
    """Modeling hook that wraps a model in ``PlusOneWrapper`` (output + 1).

    Registered in ``MODELING_HOOK_REGISTRY`` so it can be listed by name in
    ``cfg.MODEL.MODELING_HOOKS``.
    """

    def __init__(self, cfg):
        super().__init__(cfg)

    def apply(self, model: torch.nn.Module) -> torch.nn.Module:
        """Return ``model`` wrapped so that its output is incremented by 1."""
        return PlusOneWrapper(model)

    def unapply(self, model: torch.nn.Module) -> torch.nn.Module:
        """Undo ``apply`` by returning the module held inside the wrapper.

        Raises:
            AssertionError: if ``model`` is not a ``PlusOneWrapper``.
        """
        # A plain ``assert`` is acceptable here because this is test-only code.
        assert isinstance(model, PlusOneWrapper)
        return model.model


# Create a wrapper around the model that multiplies its output by 2.
class TimesTwoWrapper(torch.nn.Module):
    """Wrap ``model`` so that ``forward(x)`` returns ``model(x) * 2``.

    The wrapped module is kept as ``self.model`` so that
    ``TimesTwoHook.unapply`` can return it unchanged.
    """

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x):
        return self.model(x) * 2


Mircea Cimpoi's avatar
Mircea Cimpoi committed
59
@MODELING_HOOK_REGISTRY.register()
class TimesTwoHook(mh.ModelingHook):
    """Modeling hook that wraps a model in ``TimesTwoWrapper`` (output * 2).

    Registered in ``MODELING_HOOK_REGISTRY`` so it can be listed by name in
    ``cfg.MODEL.MODELING_HOOKS``.
    """

    def __init__(self, cfg):
        super().__init__(cfg)

    def apply(self, model: torch.nn.Module) -> torch.nn.Module:
        """Return ``model`` wrapped so that its output is multiplied by 2."""
        return TimesTwoWrapper(model)

    def unapply(self, model: torch.nn.Module) -> torch.nn.Module:
        """Undo ``apply`` by returning the module held inside the wrapper.

        Raises:
            AssertionError: if ``model`` is not a ``TimesTwoWrapper``.
        """
        # A plain ``assert`` is acceptable here because this is test-only code.
        assert isinstance(model, TimesTwoWrapper)
        return model.model


class TestModelingHook(unittest.TestCase):
    """Tests for applying and unapplying modeling hooks.

    Hooks are exercised directly, through ``build_d2go_model`` with a config,
    and through ``Detectron2GoRunner.build_model``.
    """

    @staticmethod
    def _build_test_cfg() -> CfgNode:
        """Return a minimal CPU config selecting ``TestArch`` with both test hooks.

        The hooks are listed as ``PlusOneHook`` then ``TimesTwoHook``. The tests
        expect ``model(2) == 10``, i.e. ``(2 * 2 + 1) * 2``.
        """
        cfg = CfgNode()
        cfg.MODEL = CfgNode()
        cfg.MODEL.DEVICE = "cpu"
        cfg.MODEL.META_ARCHITECTURE = "TestArch"
        cfg.MODEL.MODELING_HOOKS = ["PlusOneHook", "TimesTwoHook"]
        return cfg

    def test_modeling_hook_simple(self):
        """Applying a hook wraps the model; unapplying returns the same object."""
        model = TestArch(None)
        hook = PlusOneHook(None)
        model_with_hook = hook.apply(model)
        self.assertEqual(model_with_hook(2), 5)  # 2 * 2 + 1
        original_model = hook.unapply(model_with_hook)
        # The exact original module object must come back, not a copy.
        self.assertIs(original_model, model)

    def test_modeling_hook_cfg(self):
        """Create model with modeling hook using build_model"""
        cfg = self._build_test_cfg()

        model_info: D2GoModelBuildResult = build_d2go_model(cfg)
        model: torch.nn.Module = model_info.model
        modeling_hooks: List[mh.ModelingHook] = model_info.modeling_hooks

        self.assertEqual(model(2), 10)
        self.assertEqual(len(modeling_hooks), 2)

        self.assertTrue(hasattr(model, "_modeling_hooks"))
        self.assertTrue(hasattr(model, "unapply_modeling_hooks"))
        orig_model = model.unapply_modeling_hooks()
        self.assertIsInstance(orig_model, TestArch)
        self.assertEqual(orig_model(2), 4)

    def test_modeling_hook_runner(self):
        """Create model with modeling hook from runner"""
        runner = default_runner.Detectron2GoRunner()
        # Register the writer cleanup up front so it also runs when an
        # assertion below fails, not only on the success path.
        self.addCleanup(default_runner._close_all_tbx_writers)

        cfg = runner.get_default_cfg()
        cfg.MODEL.DEVICE = "cpu"
        cfg.MODEL.META_ARCHITECTURE = "TestArch"
        cfg.MODEL.MODELING_HOOKS = ["PlusOneHook", "TimesTwoHook"]
        model = runner.build_model(cfg)
        self.assertEqual(model(2), 10)

        self.assertTrue(hasattr(model, "_modeling_hooks"))
        self.assertTrue(hasattr(model, "unapply_modeling_hooks"))
        orig_model = model.unapply_modeling_hooks()
        self.assertIsInstance(orig_model, TestArch)
        self.assertEqual(orig_model(2), 4)

    def test_modeling_hook_copy(self):
        """Create model with modeling hook, the model could be copied"""
        cfg = self._build_test_cfg()

        model_info: D2GoModelBuildResult = build_d2go_model(cfg)
        model: torch.nn.Module = model_info.model
        modeling_hooks: List[mh.ModelingHook] = model_info.modeling_hooks

        self.assertEqual(model(2), 10)
        self.assertEqual(len(modeling_hooks), 2)

        model_copy = copy.deepcopy(model)
        self.assertIsNot(model_copy, model)
        # The copy must behave like the hooked original before any unapply.
        self.assertEqual(model_copy(2), 10)

        orig_model = model.unapply_modeling_hooks()
        self.assertIsInstance(orig_model, TestArch)
        self.assertEqual(orig_model(2), 4)

        # Unapplying hooks on the original must leave the copy able to
        # unapply its own hooks.
        orig_model_copy = model_copy.unapply_modeling_hooks()
        self.assertIsInstance(orig_model_copy, TestArch)
        self.assertEqual(orig_model_copy(2), 4)