import json
from pathlib import Path

from nni.common.framework import get_default_framework, set_default_framework
from nni.retiarii import *

# Framework that was the default when this module was imported; teardown_module
# restores it so this test module does not leak its framework switch.
original_framework = get_default_framework()

# Candidate operations DebugMutator chooses between for each pooling node.
# NOTE(review): these are created at import time, i.e. before setup_module
# switches the default framework to TensorFlow. If Operation.new dispatches on
# the default framework, they are built for `original_framework` -- confirm.
max_pool = Operation.new('MaxPool2D', {'pool_size': 2})
avg_pool = Operation.new('AveragePooling2D', {'pool_size': 2})
global_pool = Operation.new('GlobalAveragePooling2D')


def setup_module(module):
    """Make TensorFlow the default framework before this module's tests run.

    ``module`` is not used (the ``__main__`` runner passes ``None``).
    """
    # NOTE(review): max_pool/avg_pool/global_pool and model0 are constructed at
    # import time, before this hook runs, so they are not created under the
    # TensorFlow default set here -- confirm this does not matter for
    # Operation.new / Model._load.
    set_default_framework('tensorflow')


def teardown_module(module):
    """Restore the default framework that was active when this module was imported.

    ``module`` is not used (the ``__main__`` runner passes ``None``).
    """
    set_default_framework(original_framework)


class DebugSampler(Sampler):
    """Deterministic sampler that rotates through the candidate list.

    The candidate returned for a choice is offset by the number of mutations
    started so far, so successive mutations pick different candidates.
    """

    def __init__(self):
        # Number of ``mutation_start`` calls seen so far.
        self.iteration = 0

    def mutation_start(self, mutator, model):
        self.iteration = self.iteration + 1

    def choice(self, candidates, mutator, model, index):
        # Shift by the current iteration and wrap around the candidate list.
        offset = self.iteration + index
        return candidates[offset % len(candidates)]


class DebugMutator(Mutator):
    """Mutator that re-picks the operation of ``pool1`` and ``pool2`` in the ``stem`` graph."""

    def mutate(self, model):
        # A single list shared by both choices; test_dry_run compares the
        # recorded candidates against a list, so keep it a list.
        candidate_ops = [max_pool, avg_pool, global_pool]
        # Visit pool1 before pool2; self.choice is called once per node in that order.
        for node_name in ('pool1', 'pool2'):
            node = model.graphs['stem'].get_node_by_name(node_name)
            node.update_operation(self.choice(candidate_ops))


# Shared fixtures: one deterministic sampler bound to one mutator, used by all tests.
sampler = DebugSampler()
mutator = DebugMutator()
mutator.bind_sampler(sampler)


# Base model, loaded from the IR dump stored next to this test file.
json_path = Path(__file__).parent / 'mnist-tensorflow.json'
# read_text opens and closes the file itself (the previous json.load(path.open())
# leaked the handle); JSON is UTF-8, so don't depend on the locale encoding.
ir = json.loads(json_path.read_text(encoding='utf-8'))
model0 = Model._load(ir)


def test_dry_run():
    """A dry run of DebugMutator reports its two choices with their candidate lists."""
    candidates, _ = mutator.dry_run(model0)
    # DebugMutator makes exactly two choices (pool1, pool2) over the same ops.
    assert len(candidates) == 2
    assert candidates[0] == [max_pool, avg_pool, global_pool]
    assert candidates[1] == [max_pool, avg_pool, global_pool]


def test_mutation():
    """Apply DebugMutator twice; check chosen ops, history, and that inputs are untouched."""
    # The sampler is a shared module-level object whose picks depend on its
    # iteration counter. Rewind it so the expected picks below hold regardless
    # of test order or of this test being run more than once.
    sampler.iteration = 0

    model1 = mutator.apply(model0)
    assert _get_pools(model1) == (avg_pool, global_pool)

    model2 = mutator.apply(model1)
    assert _get_pools(model2) == (global_pool, max_pool)

    # model2's history should record both mutation steps, in order.
    assert len(model2.history) == 2
    assert model2.history[0].from_ == model0
    assert model2.history[0].to == model1
    assert model2.history[1].from_ == model1
    assert model2.history[1].to == model2
    assert model2.history[0].mutator == mutator
    assert model2.history[1].mutator == mutator

    # Applying the mutator must leave the source models unchanged.
    assert _get_pools(model0) == (max_pool, max_pool)
    assert _get_pools(model1) == (avg_pool, global_pool)


def _get_pools(model):
    """Return ``(pool1_op, pool2_op)``: the operations of the two pool nodes in the ``stem`` graph."""
    return tuple(
        model.graphs['stem'].get_node_by_name(node_name).operation
        for node_name in ('pool1', 'pool2')
    )


if __name__ == '__main__':
    # Run the tests without pytest, calling the module-level hooks by hand.
    setup_module(None)
    try:
        test_dry_run()
        test_mutation()
    finally:
        # Restore the framework even if a test fails.
        teardown_module(None)