"examples/llava_quant.py" did not exist on "727172e91ddfd3f5f50efd7cf38a141168186335"
test_modeling_meta_arch_modeling_hook.py 2.93 KB
Newer Older
Peizhao Zhang's avatar
Peizhao Zhang committed
1
2
3
4
5
6
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved


import unittest

7
import d2go.runner.default_runner as default_runner
Peizhao Zhang's avatar
Peizhao Zhang committed
8
import torch
9
10
from d2go.config import CfgNode
from d2go.modeling import build_model
Peizhao Zhang's avatar
Peizhao Zhang committed
11
from d2go.modeling.meta_arch import modeling_hook as mh
12
from detectron2.modeling import META_ARCH_REGISTRY
Peizhao Zhang's avatar
Peizhao Zhang committed
13
14


@META_ARCH_REGISTRY.register()
class TestArch(torch.nn.Module):
    """Minimal meta-architecture used by these tests: doubles its input.

    Registered in ``META_ARCH_REGISTRY`` so a config with
    ``MODEL.META_ARCHITECTURE = "TestArch"`` can build it via ``build_model``.
    """

    # The class name starts with "Test" and it defines ``__init__``; without
    # this flag pytest tries to collect it as a test class and emits a
    # PytestCollectionWarning.
    __test__ = False

    def __init__(self, cfg):
        # ``cfg`` is unused; it is accepted so the class can be constructed
        # as ``TestArch(cfg)``.
        super().__init__()

    def forward(self, x):
        """Return ``x * 2``."""
        return x * 2


# create a wrapper of the model that adds 1 to the output
class PlusOneWrapper(torch.nn.Module):
    """Wrap ``model`` so that ``forward`` returns ``model(x) + 1``."""

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        # Kept as an attribute so the wrapped model can be recovered later.
        self.model = model

    def forward(self, x):
        return self.model(x) + 1


@mh.MODELING_HOOK_REGISTRY.register()
class PlusOneHook(mh.ModelingHook):
    """Modeling hook that wraps a model in ``PlusOneWrapper``.

    Registered in ``MODELING_HOOK_REGISTRY`` so it can be enabled by name
    through ``cfg.MODEL.MODELING_HOOKS``.
    """

    def __init__(self, cfg):
        super().__init__(cfg)

    def apply(self, model: torch.nn.Module) -> torch.nn.Module:
        """Return ``model`` wrapped so that its output is incremented by 1."""
        return PlusOneWrapper(model)

    def unapply(self, model: torch.nn.Module) -> torch.nn.Module:
        """Undo ``apply`` and return the model that was wrapped.

        Raises:
            AssertionError: if ``model`` is not a ``PlusOneWrapper``.
        """
        assert isinstance(
            model, PlusOneWrapper
        ), f"expected PlusOneWrapper, got {type(model).__name__}"
        return model.model


# create a wrapper of the model that multiplies the output by 2
class TimesTwoWrapper(torch.nn.Module):
    """Wrap ``model`` so that ``forward`` returns ``model(x) * 2``."""

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        # Kept as an attribute so the wrapped model can be recovered later.
        self.model = model

    def forward(self, x):
        return self.model(x) * 2


@mh.MODELING_HOOK_REGISTRY.register()
class TimesTwoHook(mh.ModelingHook):
    """Modeling hook that wraps a model in ``TimesTwoWrapper``.

    Registered in ``MODELING_HOOK_REGISTRY`` so it can be enabled by name
    through ``cfg.MODEL.MODELING_HOOKS``.
    """

    def __init__(self, cfg):
        super().__init__(cfg)

    def apply(self, model: torch.nn.Module) -> torch.nn.Module:
        """Return ``model`` wrapped so that its output is multiplied by 2."""
        return TimesTwoWrapper(model)

    def unapply(self, model: torch.nn.Module) -> torch.nn.Module:
        """Undo ``apply`` and return the model that was wrapped.

        Raises:
            AssertionError: if ``model`` is not a ``TimesTwoWrapper``.
        """
        assert isinstance(
            model, TimesTwoWrapper
        ), f"expected TimesTwoWrapper, got {type(model).__name__}"
        return model.model


class TestModelingHook(unittest.TestCase):
    def test_modeling_hook_simple(self):
        """Apply and unapply a single hook directly, without a config."""
        model = TestArch(None)
        hook = PlusOneHook(None)
        model_with_hook = hook.apply(model)
        # TestArch doubles, then PlusOneWrapper adds one: 2 * 2 + 1
        self.assertEqual(model_with_hook(2), 5)
        original_model = hook.unapply(model_with_hook)
        # unapply must hand back the very same module instance, not a copy.
        self.assertIs(original_model, model)

    def test_modeling_hook_cfg(self):
        """Create model with modeling hook using build_model"""
        cfg = CfgNode()
        cfg.MODEL = CfgNode()
        cfg.MODEL.DEVICE = "cpu"
        cfg.MODEL.META_ARCHITECTURE = "TestArch"
        cfg.MODEL.MODELING_HOOKS = ["PlusOneHook", "TimesTwoHook"]
        model = build_model(cfg)
        # (2 * 2 + 1) * 2: PlusOne is expected to wrap first, TimesTwo
        # outermost (the reverse order would give 9).
        self.assertEqual(model(2), 10)

    def test_modeling_hook_runner(self):
        """Create model with modeling hook from runner"""
        runner = default_runner.Detectron2GoRunner()
        # Close the runner's tbx writers even if an assertion below fails;
        # a trailing call would be skipped on failure.
        self.addCleanup(default_runner._close_all_tbx_writers)
        cfg = runner.get_default_cfg()
        cfg.MODEL.DEVICE = "cpu"
        cfg.MODEL.META_ARCHITECTURE = "TestArch"
        cfg.MODEL.MODELING_HOOKS = ["PlusOneHook", "TimesTwoHook"]
        model = runner.build_model(cfg)
        self.assertEqual(model(2), 10)