test_modeling_meta_arch_modeling_hook.py 1.2 KB
Newer Older
Peizhao Zhang's avatar
Peizhao Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved


import unittest

import torch
from d2go.modeling.meta_arch import modeling_hook as mh


class TestArch(torch.nn.Module):
    """Toy architecture whose forward pass returns its input multiplied by 2.

    No custom ``__init__`` is needed: the inherited ``torch.nn.Module``
    constructor is used as-is.
    """

    def forward(self, x):
        doubled = x * 2
        return doubled


# Create a wrapper around the model that adds 1 to its output.
class Wrapper(torch.nn.Module):
    """Wrap a module so that its output is incremented by 1.

    The wrapped module is stored as the ``model`` attribute. ``PlusOneHook.unapply``
    reads this attribute to recover it.
    """

    # Annotated as a generic ``torch.nn.Module`` (not just ``TestArch``) because
    # ``PlusOneHook.apply`` forwards any ``torch.nn.Module`` to this constructor.
    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x):
        """Return ``self.model(x) + 1``."""
        return self.model(x) + 1


class PlusOneHook(mh.ModelingHook):
    """Modeling hook that wraps a model in ``Wrapper`` and removes it again."""

    def __init__(self, cfg):
        # Delegates straight to the base hook; this class keeps no extra state.
        super().__init__(cfg)

    def apply(self, model: torch.nn.Module) -> torch.nn.Module:
        """Return ``model`` wrapped in a ``Wrapper`` (output + 1)."""
        wrapped = Wrapper(model)
        return wrapped

    def unapply(self, model: torch.nn.Module) -> torch.nn.Module:
        """Return the module that ``apply`` wrapped.

        ``model`` must be a ``Wrapper``; anything else trips the assertion.
        """
        assert isinstance(model, Wrapper)
        inner = model.model
        return inner


class TestModelingHook(unittest.TestCase):
    """Tests for applying and un-applying a ``ModelingHook``."""

    def test_modeling_hook_simple(self):
        """``apply`` wraps the model (x * 2 + 1); ``unapply`` returns the original."""
        model = TestArch()
        hook = PlusOneHook(None)

        model_with_hook = hook.apply(model)
        self.assertIsInstance(model_with_hook, Wrapper)
        # TestArch doubles the input and Wrapper adds 1: 2 * 2 + 1 == 5.
        self.assertEqual(model_with_hook(2), 5)
        # The same composition holds elementwise for tensor inputs.
        self.assertTrue(
            torch.equal(
                model_with_hook(torch.tensor([1.0, 2.0])),
                torch.tensor([3.0, 5.0]),
            )
        )

        original_model = hook.unapply(model_with_hook)
        # nn.Module defines no __eq__, so the previous assertEqual only compared
        # identity implicitly; assertIs states that intent explicitly.
        self.assertIs(original_model, model)