# test_mutator.py
1
2
3
4
5
6
7
8
9
10
import json
from pathlib import Path
import sys

from nni.retiarii import *

# FIXME: set the framework in nni.retiarii's module-level debug config to
# 'tensorflow'. This runs before the candidate operations below are created.
# NOTE(review): presumably this global switch is a temporary stand-in for a
# proper configuration API, and Operation.new / Model._load may read it to
# resolve framework-specific op types — confirm before removing or moving it.
import nni.retiarii.debug_configs
nni.retiarii.debug_configs.framework = 'tensorflow'

11
12
13
14
# Candidate pooling operations the mutator chooses among. Each one is built by
# its own Operation.new call with its own parameter dict; the tests below
# compare these objects against model nodes with ==.
max_pool, avg_pool, global_pool = (
    Operation.new('MaxPool2D', {'pool_size': 2}),
    Operation.new('AveragePooling2D', {'pool_size': 2}),
    Operation.new('GlobalAveragePooling2D'),
)

15
16
17
18
19
20
21
22
23
24
25
26

class DebugSampler(Sampler):
    """Deterministic sampler that cycles through the candidates.

    ``iteration`` counts how many times ``mutation_start`` has been called.
    Each pick returns ``candidates[(iteration + index) % len(candidates)]``,
    so every new mutation shifts each decision forward by one position.
    ``index`` is presumably the ordinal of the ``choice`` call within a single
    mutation (TODO confirm against the Sampler contract).
    """

    def __init__(self):
        self.iteration = 0

    def choice(self, candidates, mutator, model, index):
        # ``mutator`` and ``model`` belong to the Sampler interface; this
        # debug implementation only needs the offset arithmetic.
        shift = self.iteration + index
        return candidates[shift % len(candidates)]

    def mutation_start(self, mutator, model):
        # Bump the offset so the next mutation selects different candidates.
        self.iteration = self.iteration + 1

27

28
29
30
31
32
33
34
35
36
37
class DebugMutator(Mutator):
    """Mutator that re-chooses the operation of two nodes of the 'stem' graph.

    Node 'pool1' is decided first, then 'pool2'. Each decision goes through
    ``self.choice`` over the same fixed candidate list; ``choice`` presumably
    delegates to the sampler bound via ``bind_sampler`` (TODO confirm).
    """

    def mutate(self, model):
        # The candidate order matters: the sampler picks by position.
        ops = [max_pool, avg_pool, global_pool]
        for node_name in ('pool1', 'pool2'):
            node = model.graphs['stem'].get_node_by_name(node_name)
            node.update_operation(self.choice(ops))

38

39
40
41
42
43
44
45
46
47
48
49
# Shared fixtures for the tests below. The mutator is bound to the debug
# sampler, which presumably makes its ``choice`` calls deterministic
# (TODO confirm bind_sampler semantics). NOTE: the sampler keeps state across
# tests, so test expectations depend on how many mutations have already run.
sampler = DebugSampler()
mutator = DebugMutator()
mutator.bind_sampler(sampler)


# Load the serialized IR of the base model that sits next to this file.
# ``read_text`` opens and closes the file itself, so no handle is left open,
# and JSON is read as UTF-8 regardless of the platform's locale encoding.
json_path = Path(__file__).parent / 'mnist-tensorflow.json'
ir = json.loads(json_path.read_text(encoding='utf-8'))
model0 = Model._load(ir)


def test_dry_run():
    """Dry-running the mutator reports the candidates of each of its two decisions.

    ``mutator.dry_run`` returns a pair; only the first element (one candidate
    list per decision) is checked here.
    """
    candidates, _ = mutator.dry_run(model0)
    assert len(candidates) == 2
    assert candidates[0] == [max_pool, avg_pool, global_pool]
    assert candidates[1] == [max_pool, avg_pool, global_pool]


def test_mutation():
    """Applying the mutator twice yields the expected pools and preserves history.

    The expected operations follow from DebugSampler's cycling and assume the
    shared sampler's ``iteration`` is 0 before the first ``apply`` here.
    NOTE(review): confirm that ``dry_run`` (test_dry_run) does not advance it.
    """
    model1 = mutator.apply(model0)
    assert _get_pools(model1) == (avg_pool, global_pool)

    model2 = mutator.apply(model1)
    assert _get_pools(model2) == (global_pool, max_pool)

    assert model2.history == [model0, model1]

    # Mutating must not alter the earlier models in place.
    unchanged = (
        (model0, (max_pool, max_pool)),
        (model1, (avg_pool, global_pool)),
    )
    for earlier_model, expected_pools in unchanged:
        assert _get_pools(earlier_model) == expected_pools

67

68
69
70
71
72
73
74
75
76
def _get_pools(model):
    pool1 = model.graphs['stem'].get_node_by_name('pool1').operation
    pool2 = model.graphs['stem'].get_node_by_name('pool2').operation
    return pool1, pool2


if __name__ == '__main__':
    # Run the tests in file order; test_mutation relies on shared sampler state.
    for run_test in (test_dry_run, test_mutation):
        run_test()