"vscode:/vscode.git/clone" did not exist on "8930d23d80b250b46f822473cf6d4e9e3af8c4de"
test_graph.py 931 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import json
from pathlib import Path
import sys

from nni.retiarii import *


# IR fixture files (names resolved relative to this test module's directory)
# that the load/dump round-trip test is run against.
json_files = ['mnist-tensorflow.json']


def test_model_load_dump():
    """Run the load/dump round-trip check on every file listed in ``json_files``.

    Each file name is resolved relative to the directory that contains this
    test module.
    """
    fixture_dir = Path(__file__).parent
    for name in json_files:
        _test_file(fixture_dir / name)


def _test_file(json_path):
    """Check that one graph IR JSON file survives a ``Model`` load/dump round trip.

    The file is parsed, loaded with ``Model._load``, serialized back with
    ``Model._dump``, and the dumped IR is compared with the original one after
    filling in default values the original file may omit.

    Args:
        json_path: Location of the IR JSON file (``pathlib.Path`` or ``str``).

    Raises:
        AssertionError: if the dumped IR differs from the (defaulted) original.
    """
    # Use a context manager so the file handle is closed deterministically;
    # JSON text is UTF-8, so do not depend on the platform locale encoding.
    with Path(json_path).open(encoding='utf-8') as f:
        orig_ir = json.load(f)
    model = Model._load(orig_ir)
    dump_ir = model._dump()

    # Add default values to the original JSON so it can be compared with `==`
    # (presumably `_dump` always writes 'inputs'/'outputs' -- confirm in Model).
    for graph_name, graph in orig_ir.items():
        # The `_training_config` entry is skipped: it gets no graph defaults.
        if graph_name == '_training_config':
            continue
        graph.setdefault('inputs', None)
        graph.setdefault('outputs', None)

    # Debug aid: uncomment to write both IRs to disk for diffing.
    # json.dump(orig_ir, open('_orig.json', 'w'), indent=4)
    # json.dump(dump_ir, open('_dump.json', 'w'), indent=4)

    assert orig_ir == dump_ir


if __name__ == '__main__':
    # Allow running the round-trip check directly, without a test runner.
    test_model_load_dump()