"vscode:/vscode.git/clone" did not exist on "1a37d4d2eb60fb98ce877b75d86334f1e6242b19"
test_engine.py 2.39 KB
Newer Older
1
2
3
import json
import os
import unittest
from pathlib import Path
import nni.retiarii
from nni.retiarii import Model, submit_models
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.execution import set_execution_engine
from nni.retiarii.execution.base import BaseExecutionEngine
from nni.retiarii.execution.python import PurePythonExecutionEngine
from nni.retiarii.integration import RetiariiAdvisor


class EngineTest(unittest.TestCase):
    """Tests for Retiarii PyTorch code generation and the execution engines."""

    def test_codegen(self):
        """Code generated for the MNIST graph must match ``debug_mnist_pytorch.py``."""
        model = self._load_mnist_model()
        script = model_to_pytorch_script(model)
        with open(self.enclosing_dir / 'debug_mnist_pytorch.py', encoding='utf-8') as f:
            reference_script = f.read()
        # Compare ignoring only leading/trailing whitespace.
        self.assertEqual(script.strip(), reference_script.strip())

    def test_base_execution_engine(self):
        """Submit the MNIST model twice through ``BaseExecutionEngine``."""
        # The advisor is created before the engine, keeping the original order;
        # presumably the engine looks the advisor up when constructed (TODO confirm).
        advisor = RetiariiAdvisor()
        try:
            set_execution_engine(BaseExecutionEngine())
            model = self._load_mnist_model()
            submit_models(model, model)
        finally:
            # Always stop the worker threads, even when submission fails, so a
            # failing test cannot leave threads running into later tests.
            self._stop_advisor(advisor)

    def test_py_execution_engine(self):
        """Submit a hand-built single-``LayerChoice`` model through ``PurePythonExecutionEngine``."""
        advisor = RetiariiAdvisor()
        try:
            set_execution_engine(PurePythonExecutionEngine())
            model = Model._load({
                '_model': {
                    'inputs': None,
                    'outputs': None,
                    'nodes': {
                        'layerchoice_1': {
                            'operation': {'type': 'LayerChoice', 'parameters': {'candidates': ['0', '1']}}
                        }
                    },
                    'edges': []
                }
            })
            model.python_class = object
            submit_models(model, model)
        finally:
            self._stop_advisor(advisor)

    def _load_mnist_model(self):
        """Load ``mnist_pytorch.json`` (next to this file) with ``Model._load``."""
        with open(self.enclosing_dir / 'mnist_pytorch.json', encoding='utf-8') as f:
            return Model._load(json.load(f))

    @staticmethod
    def _stop_advisor(advisor) -> None:
        """Set ``advisor.stopping`` and wait for its default and assessor worker threads to exit."""
        advisor.stopping = True
        advisor.default_worker.join()
        advisor.assessor_worker.join()

    def setUp(self) -> None:
        """Locate test data and redirect ``nni.runtime.protocol`` output to a local debug file."""
        self.enclosing_dir = Path(__file__).parent
        os.makedirs(self.enclosing_dir / 'generated', exist_ok=True)
        from nni.runtime import protocol
        # Remember the current sink (None if the module has none) so tearDown can restore it.
        self._saved_out_file = getattr(protocol, '_out_file', None)
        protocol._out_file = open(self.enclosing_dir / 'generated/debug_protocol_out_file.py', 'wb')

    def tearDown(self) -> None:
        """Close the redirected protocol file and reset the engine/advisor module globals."""
        from nni.runtime import protocol
        protocol._out_file.close()
        # Do not leave a closed file object in the module global.
        protocol._out_file = self._saved_out_file
        nni.retiarii.execution.api._execution_engine = None
        nni.retiarii.integration_api._advisor = None