# test_engine.py
1
2
3
4
5
import json
import os
import sys
import threading
import unittest
6
from pathlib import Path
7

8
import nni
9
10
11
from nni.retiarii import Model, submit_models
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.integration import RetiariiAdvisor, register_advisor
12
from nni.retiarii.evaluator.pytorch import PyTorchImageClassificationTrainer
13
14
15
from nni.retiarii.utils import import_


16
@unittest.skip('Skipped in this version')
class CodeGenTest(unittest.TestCase):
    """Compare generated PyTorch code against a stored reference script."""

    def test_mnist_example_pytorch(self):
        """Generate a script from ``mnist_pytorch.json`` and compare it.

        The model is loaded from ``mnist_pytorch.json`` and converted with
        ``model_to_pytorch_script``; the result must equal the contents of
        ``debug_mnist_pytorch.py`` once surrounding whitespace is stripped
        from both. Both files are opened relative to the current working
        directory.
        """
        with open('mnist_pytorch.json') as f:
            model = Model._load(json.load(f))
        # Generation does not need the JSON file handle, so it runs after
        # the fixture has been closed.
        script = model_to_pytorch_script(model)
        with open('debug_mnist_pytorch.py') as f:
            reference_script = f.read()
        self.assertEqual(script.strip(), reference_script.strip())


27
@unittest.skip('Skipped in this version')
class TrainerTest(unittest.TestCase):
    """Run the image-classification trainer on the debug MNIST model."""

    def test_trainer(self):
        """Build a trainer around the debug model and call ``fit`` once.

        The test passes if construction and ``fit()`` complete without
        raising; no result values are asserted.
        """
        test_dir = Path(__file__).parent
        # Put this file's directory at the front of sys.path so the
        # ``debug_mnist_pytorch`` module next to it can be imported.
        sys.path.insert(0, test_dir.as_posix())
        # Named ``model_cls`` rather than ``Model`` so the module-level
        # ``Model`` import from nni.retiarii is not shadowed.
        model_cls = import_('debug_mnist_pytorch._model')
        trainer = PyTorchImageClassificationTrainer(
            model_cls(),
            dataset_kwargs={'root': (test_dir / 'data' / 'mnist').as_posix(), 'download': True},
            dataloader_kwargs={'batch_size': 32},
            optimizer_kwargs={'lr': 1e-3},
            trainer_kwargs={'max_epochs': 1}
        )
        trainer.fit()


42
@unittest.skip('Skipped in this version')
class EngineTest(unittest.TestCase):
    """Submit models through a ``RetiariiAdvisor`` with protocol output sent to a file."""

    def test_submit_models(self):
        """Submit the same loaded model twice, then stop the advisor workers.

        ``protocol._out_file`` is pointed at
        ``generated/debug_protocol_out_file.py`` next to this test file.
        The model is loaded from ``mnist_pytorch.json`` in the current
        working directory. No result values are asserted; the test passes
        if submission and worker shutdown complete without raising.
        """
        # Create the directory the output file is opened in, not a
        # ``generated`` directory relative to the current working directory.
        out_dir = Path(__file__).parent / 'generated'
        os.makedirs(out_dir, exist_ok=True)
        from nni.runtime import protocol
        out_file = open(out_dir / 'debug_protocol_out_file.py', 'wb')
        # Close the handle once the test finishes, pass or fail.
        self.addCleanup(out_file.close)
        protocol._out_file = out_file
        advisor = RetiariiAdvisor()
        with open('mnist_pytorch.json') as f:
            model = Model._load(json.load(f))
        submit_models(model, model)

        # Flag the advisor as stopping, then wait for both workers to exit.
        advisor.stopping = True
        advisor.default_worker.join()
        advisor.assessor_worker.join()

    def test_execution_engine(self):
        """Placeholder; nothing is exercised yet."""
        pass