Unverified Commit bc0f8f33 authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

Refactor code hierarchy part 3: Unit test (#3037)

parent 80b6cb3b
...@@ -5,9 +5,10 @@ import json ...@@ -5,9 +5,10 @@ import json
from io import BytesIO from io import BytesIO
from unittest import TestCase, main from unittest import TestCase, main
import nni.protocol from nni.runtime import protocol
from nni.msg_dispatcher import MsgDispatcher from nni.runtime import msg_dispatcher_base
from nni.protocol import CommandType, send, receive from nni.runtime.msg_dispatcher import MsgDispatcher
from nni.runtime.protocol import CommandType, send, receive
from nni.tuner import Tuner from nni.tuner import Tuner
from nni.utils import extract_scalar_reward from nni.utils import extract_scalar_reward
...@@ -44,15 +45,15 @@ _out_buf = BytesIO() ...@@ -44,15 +45,15 @@ _out_buf = BytesIO()
def _reverse_io(): def _reverse_io():
_in_buf.seek(0) _in_buf.seek(0)
_out_buf.seek(0) _out_buf.seek(0)
nni.protocol._out_file = _in_buf protocol._out_file = _in_buf
nni.protocol._in_file = _out_buf protocol._in_file = _out_buf
def _restore_io(): def _restore_io():
_in_buf.seek(0) _in_buf.seek(0)
_out_buf.seek(0) _out_buf.seek(0)
nni.protocol._in_file = _in_buf protocol._in_file = _in_buf
nni.protocol._out_file = _out_buf protocol._out_file = _out_buf
class MsgDispatcherTestCase(TestCase): class MsgDispatcherTestCase(TestCase):
...@@ -68,7 +69,7 @@ class MsgDispatcherTestCase(TestCase): ...@@ -68,7 +69,7 @@ class MsgDispatcherTestCase(TestCase):
tuner = NaiveTuner() tuner = NaiveTuner()
dispatcher = MsgDispatcher(tuner) dispatcher = MsgDispatcher(tuner)
nni.msg_dispatcher_base._worker_fast_exit_on_terminate = False msg_dispatcher_base._worker_fast_exit_on_terminate = False
dispatcher.run() dispatcher.run()
e = dispatcher.worker_exceptions[0] e = dispatcher.worker_exceptions[0]
......
...@@ -8,12 +8,12 @@ from unittest import TestCase, main ...@@ -8,12 +8,12 @@ from unittest import TestCase, main
import torch import torch
import torch.nn as nn import torch.nn as nn
from nni.nas.pytorch.classic_nas import get_and_apply_next_architecture from nni.algorithms.nas.pytorch.classic_nas import get_and_apply_next_architecture
from nni.nas.pytorch.darts import DartsMutator from nni.algorithms.nas.pytorch.darts import DartsMutator
from nni.nas.pytorch.enas import EnasMutator from nni.algorithms.nas.pytorch.enas import EnasMutator
from nni.nas.pytorch.fixed import apply_fixed_architecture from nni.nas.pytorch.fixed import apply_fixed_architecture
from nni.nas.pytorch.mutables import LayerChoice from nni.nas.pytorch.mutables import LayerChoice
from nni.nas.pytorch.random import RandomMutator from nni.algorithms.nas.pytorch.random import RandomMutator
from nni.nas.pytorch.utils import _reset_global_mutable_counting from nni.nas.pytorch.utils import _reset_global_mutable_counting
......
...@@ -6,15 +6,15 @@ from unittest import TestCase, main ...@@ -6,15 +6,15 @@ from unittest import TestCase, main
from copy import deepcopy from copy import deepcopy
import torch import torch
from nni.networkmorphism_tuner.graph import graph_to_json, json_to_graph from nni.algorithms.hpo.networkmorphism_tuner.graph import graph_to_json, json_to_graph
from nni.networkmorphism_tuner.graph_transformer import ( from nni.algorithms.hpo.networkmorphism_tuner.graph_transformer import (
to_deeper_graph, to_deeper_graph,
to_skip_connection_graph, to_skip_connection_graph,
to_wider_graph, to_wider_graph,
) )
from nni.networkmorphism_tuner.layers import layer_description_extractor from nni.algorithms.hpo.networkmorphism_tuner.layers import layer_description_extractor
from nni.networkmorphism_tuner.networkmorphism_tuner import NetworkMorphismTuner from nni.algorithms.hpo.networkmorphism_tuner.networkmorphism_tuner import NetworkMorphismTuner
from nni.networkmorphism_tuner.nn import CnnGenerator from nni.algorithms.hpo.networkmorphism_tuner.nn import CnnGenerator
class NetworkMorphismTestCase(TestCase): class NetworkMorphismTestCase(TestCase):
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
import nni.protocol from nni.runtime import protocol
from nni.protocol import CommandType, send, receive from nni.runtime.protocol import CommandType, send, receive
from io import BytesIO from io import BytesIO
from unittest import TestCase, main from unittest import TestCase, main
def _prepare_send(): def _prepare_send():
nni.protocol._out_file = BytesIO() protocol._out_file = BytesIO()
return nni.protocol._out_file return protocol._out_file
def _prepare_receive(data): def _prepare_receive(data):
nni.protocol._in_file = BytesIO(data) protocol._in_file = BytesIO(data)
class ProtocolTestCase(TestCase): class ProtocolTestCase(TestCase):
......
...@@ -7,11 +7,15 @@ import torch.nn as nn ...@@ -7,11 +7,15 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.utils.data import torch.utils.data
import math import math
import sys
import unittest
from unittest import TestCase, main from unittest import TestCase, main
from nni.compression.torch import LevelPruner, SlimPruner, FPGMPruner, L1FilterPruner, \ from nni.algorithms.compression.pytorch.pruning import LevelPruner, SlimPruner, FPGMPruner, L1FilterPruner, \
L2FilterPruner, AGPPruner, ActivationMeanRankFilterPruner, ActivationAPoZRankFilterPruner, \ L2FilterPruner, AGPPruner, ActivationMeanRankFilterPruner, ActivationAPoZRankFilterPruner, \
TaylorFOWeightFilterPruner, NetAdaptPruner, SimulatedAnnealingPruner, ADMMPruner, \ TaylorFOWeightFilterPruner, NetAdaptPruner, SimulatedAnnealingPruner, ADMMPruner, \
AutoCompressPruner, AMCPruner AutoCompressPruner, AMCPruner
sys.path.append(os.path.dirname(__file__))
from models.pytorch_models.mobilenet import MobileNet from models.pytorch_models.mobilenet import MobileNet
def validate_sparsity(wrapper, sparsity, bias=False): def validate_sparsity(wrapper, sparsity, bias=False):
...@@ -229,7 +233,7 @@ def pruners_test(pruner_names=['level', 'agp', 'slim', 'fpgm', 'l1', 'l2', 'tayl ...@@ -229,7 +233,7 @@ def pruners_test(pruner_names=['level', 'agp', 'slim', 'fpgm', 'l1', 'l2', 'tayl
if os.path.exists(f): if os.path.exists(f):
os.remove(f) os.remove(f)
def test_agp(pruning_algorithm): def _test_agp(pruning_algorithm):
model = Model() model = Model()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01) optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
config_list = prune_config['agp']['config_list'] config_list = prune_config['agp']['config_list']
...@@ -260,6 +264,7 @@ class SimpleDataset: ...@@ -260,6 +264,7 @@ class SimpleDataset:
def __len__(self): def __len__(self):
return 1000 return 1000
@unittest.skipIf(torch.__version__ >= '1.6.0', 'not supported')
class PrunerTestCase(TestCase): class PrunerTestCase(TestCase):
def test_pruners(self): def test_pruners(self):
pruners_test(bias=True) pruners_test(bias=True)
...@@ -269,11 +274,11 @@ class PrunerTestCase(TestCase): ...@@ -269,11 +274,11 @@ class PrunerTestCase(TestCase):
def test_agp_pruner(self): def test_agp_pruner(self):
for pruning_algorithm in ['l1', 'l2', 'taylorfo', 'apoz']: for pruning_algorithm in ['l1', 'l2', 'taylorfo', 'apoz']:
test_agp(pruning_algorithm) _test_agp(pruning_algorithm)
for pruning_algorithm in ['level']: for pruning_algorithm in ['level']:
prune_config['agp']['config_list'][0]['op_types'] = ['default'] prune_config['agp']['config_list'][0]['op_types'] = ['default']
test_agp(pruning_algorithm) _test_agp(pruning_algorithm)
def testAMC(self): def testAMC(self):
model = MobileNet(n_class=10) model = MobileNet(n_class=10)
......
...@@ -6,7 +6,7 @@ import os ...@@ -6,7 +6,7 @@ import os
os.environ['NNI_PLATFORM'] = 'unittest' os.environ['NNI_PLATFORM'] = 'unittest'
import nni import nni
import nni.platform.test as test_platform import nni.runtime.platform.test as test_platform
import nni.trial import nni.trial
from unittest import TestCase, main from unittest import TestCase, main
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# Licensed under the MIT license. # Licensed under the MIT license.
import nni import nni
import nni.platform.test as test_platform import nni.runtime.platform.test as test_platform
import nni.trial import nni.trial
import numpy as np import numpy as np
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from nni.tools import annotation
import ast
import json
from pathlib import Path
import shutil
import tempfile
import pytest
cwd = Path(__file__).parent
shutil.rmtree(cwd / '_generated', ignore_errors=True)
shutil.copytree(cwd / 'testcase/annotated', cwd / '_generated/annotated')
def test_search_space_generator():
search_space = annotation.generate_search_space(cwd / '_generated/annotated')
expected = json.load((cwd / 'testcase/searchspace.json').open())
assert search_space == expected
def test_code_generator():
src_dir = cwd / 'testcase/usercode'
dst_dir = cwd / '_generated/usercode'
code_dir = annotation.expand_annotations(src_dir, dst_dir, nas_mode='classic_mode')
assert Path(code_dir) == dst_dir
expect_dir = cwd / 'testcase/annotated'
_assert_source_equal(dst_dir, expect_dir, 'dir/simple.py')
_assert_source_equal(dst_dir, expect_dir, 'mnist.py')
_assert_source_equal(dst_dir, expect_dir, 'nas.py')
assert (src_dir / 'nonpy.txt').read_text() == (dst_dir / 'nonpy.txt').read_text()
def test_annotation_detecting():
src_dir = cwd / 'testcase/usercode/non_annotation'
code_dir = annotation.expand_annotations(src_dir, tempfile.mkdtemp())
assert Path(code_dir) == src_dir
def _assert_source_equal(dir1, dir2, file_name):
ast1 = ast.parse((dir1 / file_name).read_text())
ast2 = ast.parse((dir2 / file_name).read_text())
_assert_ast_equal(ast1, ast2)
def _assert_ast_equal(ast1, ast2):
assert type(ast1) is type(ast2)
if isinstance(ast1, ast.AST):
assert sorted(ast1._fields) == sorted(ast2._fields)
for field_name in ast1._fields:
field1 = getattr(ast1, field_name)
field2 = getattr(ast2, field_name)
_assert_ast_equal(field1, field2)
elif isinstance(ast1, list):
assert len(ast1) == len(ast2)
for item1, item2 in zip(ast1, ast2):
_assert_ast_equal(item1, item2)
else:
assert ast1 == ast2
if __name__ == '__main__':
pytest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment