import json
from pathlib import Path
import sys

from nni.retiarii import *

# FIXME
import nni.retiarii.debug_configs
# Framework value in effect before this module changes it, kept so that
# teardown_module can put it back.
original_framework = nni.retiarii.debug_configs.framework

# Candidate pooling operations that DebugMutator offers at each pool node.
# NOTE(review): these are created at import time, i.e. before setup_module
# switches debug_configs.framework to 'tensorflow'. If Operation.new consults
# that setting, these ops are built for the original framework — confirm.
max_pool, avg_pool, global_pool = (
    Operation.new('MaxPool2D', {'pool_size': 2}),
    Operation.new('AveragePooling2D', {'pool_size': 2}),
    Operation.new('GlobalAveragePooling2D'),
)

def setup_module(module):
    """Select the TensorFlow framework for the tests in this module.

    Named after pytest's xunit-style module setup hook. ``module`` is unused;
    the ``__main__`` block calls this directly with ``None``.
    """
    configs = nni.retiarii.debug_configs
    configs.framework = 'tensorflow'


def teardown_module(module):
    """Restore the framework value captured in ``original_framework`` at import time."""
    configs = nni.retiarii.debug_configs
    configs.framework = original_framework


class DebugSampler(Sampler):
    """Deterministic sampler that cycles through the candidate list.

    ``iteration`` is bumped in ``mutation_start`` (presumably invoked once at
    the start of each mutation). The pick for a choice is
    ``candidates[(iteration + index) % len(candidates)]``, so successive
    choices and successive mutations land on different candidates.
    """

    def __init__(self):
        # Count of mutation_start calls seen so far.
        self.iteration = 0

    def mutation_start(self, mutator, model):
        self.iteration = self.iteration + 1

    def choice(self, candidates, mutator, model, index):
        offset = self.iteration + index
        return candidates[offset % len(candidates)]

class DebugMutator(Mutator):
    """Mutator that re-chooses the operation of the ``pool1`` and ``pool2`` nodes.

    Both nodes are looked up in the ``stem`` graph. Each is offered the same
    three pooling candidates, and the choices are made in node order (pool1,
    then pool2) via ``self.choice``. That method presumably delegates to the
    sampler attached with ``bind_sampler``.
    """

    def mutate(self, model):
        candidates = [max_pool, avg_pool, global_pool]
        for node_name in ('pool1', 'pool2'):
            node = model.graphs['stem'].get_node_by_name(node_name)
            node.update_operation(self.choice(candidates))

# Shared fixtures: one mutator bound to a deterministic sampler. Tests rely on
# the sampler's counter starting at 0.
sampler = DebugSampler()
mutator = DebugMutator()
mutator.bind_sampler(sampler)


# Base model loaded from the IR JSON stored next to this test file.
# NOTE(review): this runs at import time, before setup_module switches
# debug_configs.framework; confirm Model._load does not depend on it.
json_path = Path(__file__).parent / 'mnist-tensorflow.json'
# read_text opens and closes the file in one call, so no handle is left open.
# JSON is UTF-8 by spec, so don't rely on the locale's default encoding.
ir = json.loads(json_path.read_text(encoding='utf-8'))
model0 = Model._load(ir)


def test_dry_run():
    """dry_run on the base model reports the candidates offered at each choice.

    DebugMutator makes two choices (pool1, pool2). Each should offer the same
    three pooling ops, in order.
    """
    candidates, _ = mutator.dry_run(model0)
    assert len(candidates) == 2
    assert candidates[0] == [max_pool, avg_pool, global_pool]
    assert candidates[1] == [max_pool, avg_pool, global_pool]


def test_mutation():
    """Apply DebugMutator twice and check chosen ops, history, and input immutability.

    DebugSampler picks ``ops[(iteration + index) % 3]`` from
    ``[max_pool, avg_pool, global_pool]``. This assumes three things:
    the two choices of a mutation get indices 0 and 1; each apply bumps
    ``iteration`` once; and the shared sampler starts at 0.
    Under those assumptions the first apply yields (avg_pool, global_pool)
    and the second yields (global_pool, max_pool).
    """
    model1 = mutator.apply(model0)
    assert _get_pools(model1) == (avg_pool, global_pool)

    model2 = mutator.apply(model1)
    assert _get_pools(model2) == (global_pool, max_pool)

    # model2 records both mutation steps, oldest first.
    assert len(model2.history) == 2
    assert model2.history[0].from_ == model0
    assert model2.history[0].to == model1
    assert model2.history[1].from_ == model1
    assert model2.history[1].to == model2
    assert model2.history[0].mutator == mutator
    assert model2.history[1].mutator == mutator

    # The input models must be left unchanged by apply().
    assert _get_pools(model0) == (max_pool, max_pool)
    assert _get_pools(model1) == (avg_pool, global_pool)

def _get_pools(model):
    pool1 = model.graphs['stem'].get_node_by_name('pool1').operation
    pool2 = model.graphs['stem'].get_node_by_name('pool2').operation
    return pool1, pool2


if __name__ == '__main__':
    # Allow running this file directly, without pytest. Replicate the module
    # setup/teardown around the tests, and restore the framework setting even
    # if a test fails.
    setup_module(None)
    try:
        test_dry_run()
        test_mutation()
    finally:
        teardown_module(None)