"vscode:/vscode.git/clone" did not exist on "a98a53f5f7c99223972f6805095f75f4c26960ed"
test_engine.py 3 KB
Newer Older
1
2
3
import json
import os
import unittest
4
from pathlib import Path
5

6
import nni.retiarii
7
import nni.retiarii.integration_api
8
9
from nni.retiarii import Model, submit_models
from nni.retiarii.codegen import model_to_pytorch_script
10
11
12
from nni.retiarii.execution import set_execution_engine
from nni.retiarii.execution.base import BaseExecutionEngine
from nni.retiarii.execution.python import PurePythonExecutionEngine
13
from nni.retiarii.graph import DebugEvaluator
14
from nni.retiarii.integration import RetiariiAdvisor
15
from nni.runtime.tuner_command_channel.legacy import *
16

17
18
19
class EngineTest(unittest.TestCase):
    """Tests for Retiarii code generation and the execution engines.

    Every test reads its fixtures from the directory that contains this file
    (``self.enclosing_dir``, assigned in :meth:`setUp`).
    """

    def test_codegen(self):
        """Generate a PyTorch script from ``mnist_pytorch.json`` and compare it
        with the reference script ``debug_mnist_pytorch.py``.

        Both scripts are compared after stripping surrounding whitespace.
        """
        with open(self.enclosing_dir / 'mnist_pytorch.json') as f:
            model = Model._load(json.load(f))
            script = model_to_pytorch_script(model)
        with open(self.enclosing_dir / 'debug_mnist_pytorch.py') as f:
            reference_script = f.read()
        self.assertEqual(script.strip(), reference_script.strip())

    def test_base_execution_engine(self):
        """Submit the MNIST model twice through ``BaseExecutionEngine``."""
        advisor = self._start_advisor()
        try:
            set_execution_engine(BaseExecutionEngine())
            with open(self.enclosing_dir / 'mnist_pytorch.json') as f:
                model = Model._load(json.load(f))
            submit_models(model, model)
        finally:
            # Stop the workers even when submission fails, so a failing test
            # does not leave advisor threads running behind it.
            self._stop_advisor(advisor)

    def test_py_execution_engine(self):
        """Submit a one-node LayerChoice model twice through
        ``PurePythonExecutionEngine``.
        """
        advisor = self._start_advisor()
        try:
            set_execution_engine(PurePythonExecutionEngine())
            model = Model._load({
                '_model': {
                    'inputs': None,
                    'outputs': None,
                    'nodes': {
                        'layerchoice_1': {
                            'operation': {'type': 'LayerChoice', 'parameters': {'candidates': ['0', '1']}}
                        }
                    },
                    'edges': []
                }
            })
            model.evaluator = DebugEvaluator()
            model.python_class = object
            submit_models(model, model)
        finally:
            # Same reasoning as in test_base_execution_engine: always stop
            # and join the worker threads.
            self._stop_advisor(advisor)

    def _start_advisor(self):
        """Reset the global advisor/engine and return a started advisor.

        The module-level ``_advisor`` and ``_execution_engine`` are cleared
        first, then a ``RetiariiAdvisor`` is created with a placeholder URL,
        its ``_channel`` is replaced by a ``LegacyCommandChannel`` and both
        of its worker threads are started.

        Returns:
            The advisor; pass it to :meth:`_stop_advisor` when done.
        """
        nni.retiarii.integration_api._advisor = None
        nni.retiarii.execution.api._execution_engine = None
        advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
        # NOTE(review): presumably the legacy channel writes to the file
        # installed by setUp via _set_out_file -- confirm against nni.
        advisor._channel = LegacyCommandChannel()
        advisor.default_worker.start()
        advisor.assessor_worker.start()
        return advisor

    @staticmethod
    def _stop_advisor(advisor):
        """Flag ``advisor`` as stopping and join both of its workers."""
        advisor.stopping = True
        advisor.default_worker.join()
        advisor.assessor_worker.join()

    def setUp(self) -> None:
        """Locate the fixture directory and install the protocol out-file.

        Creates ``generated/`` next to this file if needed and hands a binary
        write handle on ``generated/debug_protocol_out_file.py`` to
        ``_set_out_file``; the handle is closed again in :meth:`tearDown`.
        """
        self.enclosing_dir = Path(__file__).parent
        os.makedirs(self.enclosing_dir / 'generated', exist_ok=True)
        _set_out_file(open(self.enclosing_dir / 'generated/debug_protocol_out_file.py', 'wb'))

    def tearDown(self) -> None:
        """Close the protocol out-file and clear the global engine/advisor."""
        try:
            _get_out_file().close()
        finally:
            # Reset the singletons even if closing the file raises, so state
            # from one test cannot reach the next.
            nni.retiarii.execution.api._execution_engine = None
            nni.retiarii.integration_api._advisor = None