"docs/zh_CN/TrainingService/LocalMode.rst" did not exist on "abc221589c65d75b494407c60a81ca87c3020463"
test_engine.py 2.94 KB
Newer Older
1
2
3
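"""Unit tests for the Retiarii execution engines (test_engine.py).

Covers PyTorch code generation from a serialized graph model, and model
submission through both the graph-based and pure-Python execution engines.
"""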
import json
import os
import unittest
from pathlib import Path

import nni.retiarii
from nni.retiarii import Model, submit_models
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.execution import set_execution_engine
from nni.retiarii.execution.base import BaseExecutionEngine
from nni.retiarii.execution.python import PurePythonExecutionEngine
from nni.retiarii.graph import DebugEvaluator
from nni.retiarii.integration import RetiariiAdvisor
from nni.runtime.tuner_command_channel.legacy import *

class EngineTest(unittest.TestCase):
    def test_codegen(self):
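        """Code generated from the serialized MNIST graph should match the checked-in reference script."""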
        with open(self.enclosing_dir / 'mnist_pytorch.json') as f:
            model = Model._load(json.load(f))
            script = model_to_pytorch_script(model)
        with open(self.enclosing_dir / 'debug_mnist_pytorch.py') as f:
            reference_script = f.read()
        self.assertEqual(script.strip(), reference_script.strip())

    def test_base_execution_engine(self):
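        """Smoke test: submit two copies of the MNIST graph model through the graph-based engine."""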
        nni.retiarii.integration_api._advisor = None
        nni.retiarii.execution.api._execution_engine = None
        advisor = RetiariiAdvisor('ws://_placeholder_')
        advisor._channel = LegacyCommandChannel()
        advisor.default_worker.start()
        advisor.assessor_worker.start()

        set_execution_engine(BaseExecutionEngine())
        with open(self.enclosing_dir / 'mnist_pytorch.json') as f:
            model = Model._load(json.load(f))
        submit_models(model, model)

        advisor.stopping = True
        advisor.default_worker.join()
        advisor.assessor_worker.join()

    def test_py_execution_engine(self):
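        """Smoke test: submit a hand-built model containing a LayerChoice node through the pure-Python engine."""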
        nni.retiarii.integration_api._advisor = None
        nni.retiarii.execution.api._execution_engine = None
        advisor = RetiariiAdvisor('ws://_placeholder_')
        advisor._channel = LegacyCommandChannel()
        advisor.default_worker.start()
        advisor.assessor_worker.start()

        set_execution_engine(PurePythonExecutionEngine())
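        # Minimal hand-written graph IR: a single LayerChoice node with two candidates and no edges.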
        model = Model._load({
            '_model': {
                'inputs': None,
                'outputs': None,
                'nodes': {
                    'layerchoice_1': {
                        'operation': {'type': 'LayerChoice', 'parameters': {'candidates': ['0', '1']}}
                    }
                },
                'edges': []
            }
        })
        model.evaluator = DebugEvaluator()
        model.python_class = object
        submit_models(model, model)

        advisor.stopping = True
        advisor.default_worker.join()
        advisor.assessor_worker.join()

    def setUp(self) -> None:
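        """Redirect legacy command-protocol output to a debug file under 'generated/'."""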
        self.enclosing_dir = Path(__file__).parent
        os.makedirs(self.enclosing_dir / 'generated', exist_ok=True)
        _set_out_file(open(self.enclosing_dir / 'generated/debug_protocol_out_file.py', 'wb'))

    def tearDown(self) -> None:
        _get_out_file().close()
        nni.retiarii.execution.api._execution_engine = None
        nni.retiarii.integration_api._advisor = None
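

# Assumed entry point for running this module directly; in the repository these
# tests are typically collected by the unittest/pytest runner instead.
if __name__ == '__main__':
    unittest.main()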