# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

import pytest

import torch

from nni.algorithms.compression.v2.pytorch.utils.evaluator import (
    TensorHook,
    ForwardHook,
    BackwardHook,
)

from ..assets.device import device
from ..assets.simple_mnist import (
    SimpleLightningModel,
    SimpleTorchModel,
    create_lighting_evaluator,
    create_pytorch_evaluator
)


# Module-level flags flipped by the patch callbacks below; the test asserts
# them to prove the evaluator really invoked each patch.
optimizer_before_step_flag = optimizer_after_step_flag = loss_flag = False

def optimizer_before_step_patch():
    """Callback hooked in before ``optimizer.step()``; records that it ran."""
    global optimizer_before_step_flag
    optimizer_before_step_flag = True

def optimizer_after_step_patch():
    """Callback hooked in after ``optimizer.step()``; records that it ran."""
    global optimizer_after_step_flag
    optimizer_after_step_flag = True

def loss_patch(t: torch.Tensor) -> torch.Tensor:
    """Identity loss wrapper: passes the loss through untouched and records
    that the evaluator routed the loss through the patch."""
    global loss_flag
    loss_flag = True
    return t

def tensor_hook_factory(buffer: list):
    """Build a tensor hook that appends ``True`` to *buffer* on every call,
    letting the test count how many times the hook fired."""
    def record(grad: torch.Tensor):
        buffer.append(True)
    return record

def forward_hook_factory(buffer: list):
    """Build a module forward hook that appends ``True`` to *buffer* each
    time the hooked module runs a forward pass."""
    def record(module: torch.nn.Module, inputs: torch.Tensor, outputs: torch.Tensor):
        buffer.append(True)
    return record

def backward_hook_factory(buffer: list):
    """Build a module backward hook that appends ``True`` to *buffer* each
    time gradients flow back through the hooked module."""
    def record(module: torch.nn.Module, grad_in: torch.Tensor, grad_out: torch.Tensor):
        buffer.append(True)
    return record

def reset_flags():
    """Clear every patch flag so a fresh train/finetune round starts clean."""
    global optimizer_before_step_flag, optimizer_after_step_flag, loss_flag
    optimizer_before_step_flag = optimizer_after_step_flag = loss_flag = False

def assert_flags():
    """Fail the test if any evaluator patch did not fire during the last run."""
    # Read-only access to module globals, so no ``global`` declaration needed.
    assert optimizer_before_step_flag, 'Evaluator patch optimizer before step failed.'
    assert optimizer_after_step_flag, 'Evaluator patch optimizer after step failed.'
    assert loss_flag, 'Evaluator patch loss failed.'


@pytest.mark.parametrize("evaluator_type", ['lightning', 'pytorch'])
def test_evaluator(evaluator_type: str):
    """Exercise evaluator patching and hook registration end to end.

    For both the lightning-backed and pure-pytorch evaluators: bind a model,
    register a tensor/forward/backward hook on the first conv layer, then
    check that ``train`` and ``finetune`` trigger every patched callback and
    fill each hook buffer the expected number of times.
    """
    if evaluator_type == 'lightning':
        model = SimpleLightningModel()
        evaluator = create_lighting_evaluator()
        evaluator._init_optimizer_helpers(model)
        evaluator.bind_model(model)
        # Lightning wraps the torch model, hence the extra `model.` prefix.
        target_layer, prefix = model.model.conv1, 'model.conv1'
    elif evaluator_type == 'pytorch':
        model = SimpleTorchModel().to(device)
        evaluator = create_pytorch_evaluator(model)
        evaluator._init_optimizer_helpers(model)
        evaluator.bind_model(model)
        target_layer, prefix = model.conv1, 'conv1'
    else:
        raise ValueError(f'wrong evaluator_type: {evaluator_type}')

    hooks = [
        TensorHook(target_layer.weight, f'{prefix}.weight', tensor_hook_factory),
        ForwardHook(target_layer, prefix, forward_hook_factory),
        BackwardHook(target_layer, prefix, backward_hook_factory),
    ]

    # test train with patch & hook
    reset_flags()
    evaluator.patch_loss(loss_patch)
    evaluator.patch_optimizer_step([optimizer_before_step_patch], [optimizer_after_step_patch])
    evaluator.register_hooks(hooks)

    evaluator.train(max_steps=1)
    assert_flags()
    # A single step should fire each hook exactly once.
    assert all(len(hook.buffer) == 1 for hook in hooks)

    # test finetune with patch & hook
    reset_flags()
    evaluator.remove_all_hooks()
    evaluator.register_hooks(hooks)

    evaluator.finetune()
    assert_flags()
    assert all(len(hook.buffer) == 50 for hook in hooks)