# test_engine.py
import json
import os
import unittest
4
from pathlib import Path
5

6
import nni.retiarii
7
import nni.retiarii.integration_api
8
9
from nni.retiarii import Model, submit_models
from nni.retiarii.codegen import model_to_pytorch_script
10
11
12
from nni.retiarii.execution import set_execution_engine
from nni.retiarii.execution.base import BaseExecutionEngine
from nni.retiarii.execution.python import PurePythonExecutionEngine
13
from nni.retiarii.graph import DebugEvaluator
14
from nni.retiarii.integration import RetiariiAdvisor
15
from nni.runtime.tuner_command_channel.legacy import *
16

17
18
19
class EngineTest(unittest.TestCase):
    """Tests for Retiarii code generation and the base / pure-Python execution engines."""

    def test_codegen(self):
        """The PyTorch script generated from ``mnist_pytorch.json`` matches
        the checked-in ``debug_mnist_pytorch.py`` (ignoring leading/trailing whitespace)."""
        with open(self.enclosing_dir / 'mnist_pytorch.json') as f:
            model = Model._load(json.load(f))
        script = model_to_pytorch_script(model)
        with open(self.enclosing_dir / 'debug_mnist_pytorch.py') as f:
            reference_script = f.read()
        self.assertEqual(script.strip(), reference_script.strip())

    def test_base_execution_engine(self):
        """Submitting a graph model through ``BaseExecutionEngine`` completes without error."""
        advisor = self._start_advisor()
        try:
            set_execution_engine(BaseExecutionEngine())
            with open(self.enclosing_dir / 'mnist_pytorch.json') as f:
                model = Model._load(json.load(f))
            submit_models(model, model)
        finally:
            # Stop the advisor even when the test body fails, so its worker
            # threads are not left running after the test.
            self._stop_advisor(advisor)

    def test_py_execution_engine(self):
        """Submitting a minimal single-LayerChoice model through
        ``PurePythonExecutionEngine`` completes without error."""
        advisor = self._start_advisor()
        try:
            set_execution_engine(PurePythonExecutionEngine())
            model = Model._load({
                '_model': {
                    'inputs': None,
                    'outputs': None,
                    'nodes': {
                        'layerchoice_1': {
                            'operation': {'type': 'LayerChoice', 'parameters': {'candidates': ['0', '1']}}
                        }
                    },
                    'edges': []
                }
            })
            model.evaluator = DebugEvaluator()
            # NOTE(review): `object` looks like a placeholder class for the
            # pure-Python engine; confirm what the engine does with it.
            model.python_class = object
            submit_models(model, model)
        finally:
            self._stop_advisor(advisor)

    def _start_advisor(self):
        """Reset the global Retiarii advisor/engine and start a fresh advisor.

        The advisor is marked initialized, wired to a ``LegacyCommandChannel``,
        and its default and assessor workers are started.

        Returns:
            The running ``RetiariiAdvisor``; pass it to ``_stop_advisor`` when done.
        """
        # Clear module-level singletons left over from a previous test.
        nni.retiarii.integration_api._advisor = None
        nni.retiarii.execution.api._execution_engine = None
        advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
        advisor._advisor_initialized = True
        advisor._channel = LegacyCommandChannel()
        advisor.default_worker.start()
        advisor.assessor_worker.start()
        return advisor

    @staticmethod
    def _stop_advisor(advisor):
        """Signal the advisor to stop and wait for both of its worker threads."""
        advisor.stopping = True
        advisor.default_worker.join()
        advisor.assessor_worker.join()

    def setUp(self) -> None:
        """Ensure ``generated/`` exists next to this file and route the
        legacy channel's output to ``generated/debug_protocol_out_file.py``."""
        self.enclosing_dir = Path(__file__).parent
        os.makedirs(self.enclosing_dir / 'generated', exist_ok=True)
        # The handle is closed in tearDown via _get_out_file().
        _set_out_file(open(self.enclosing_dir / 'generated/debug_protocol_out_file.py', 'wb'))

    def tearDown(self) -> None:
        """Close the protocol output file and reset the global engine and advisor."""
        _get_out_file().close()
        nni.retiarii.execution.api._execution_engine = None
        nni.retiarii.integration_api._advisor = None