Unverified Commit 896c516f authored by J-shang's avatar J-shang Committed by GitHub
Browse files

[UT] move compression ut to an additional folder & speed up simulated annealing search (#4357)

parent 822556b7
...@@ -229,8 +229,11 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator): ...@@ -229,8 +229,11 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
if target_sparsity == 0: if target_sparsity == 0:
return [], [] return [], []
low_limit = 0
while True: while True:
random_sparsity = sorted(np.random.uniform(0, 1, len(op_names))) # This is to speed up finding the legal sparsity.
low_limit = (1 - low_limit) * 0.05 + low_limit
random_sparsity = sorted(np.random.uniform(low_limit, 1, len(op_names)))
rescaled_sparsity = self._rescale_sparsity(random_sparsity, target_sparsity, op_names) rescaled_sparsity = self._rescale_sparsity(random_sparsity, target_sparsity, op_names)
if rescaled_sparsity is not None and rescaled_sparsity[0] >= 0 and rescaled_sparsity[-1] < 1: if rescaled_sparsity is not None and rescaled_sparsity[0] >= 0 and rescaled_sparsity[-1] < 1:
break break
......
...@@ -179,14 +179,14 @@ stages: ...@@ -179,14 +179,14 @@ stages:
set -e set -e
cd test cd test
python -m pytest ut --cov-config=.coveragerc \ python -m pytest ut --cov-config=.coveragerc \
--ignore=ut/sdk/test_pruners.py \ --ignore=ut/compression/v1/test_pruners.py \
--ignore=ut/sdk/test_compressor_tf.py \ --ignore=ut/compression/v1/test_compressor_tf.py \
--ignore=ut/sdk/test_compressor_torch.py \ --ignore=ut/compression/v1/test_compressor_torch.py \
--ignore=ut/sdk/test_model_speedup.py --ignore=ut/compression/v1/test_model_speedup.py
python -m pytest ut/sdk/test_pruners.py --cov-config=.coveragerc --cov-append python -m pytest ut/compression/v1/test_pruners.py --cov-config=.coveragerc --cov-append
python -m pytest ut/sdk/test_compressor_tf.py --cov-config=.coveragerc --cov-append python -m pytest ut/compression/v1/test_compressor_tf.py --cov-config=.coveragerc --cov-append
python -m pytest ut/sdk/test_compressor_torch.py --cov-config=.coveragerc --cov-append python -m pytest ut/compression/v1/test_compressor_torch.py --cov-config=.coveragerc --cov-append
python -m pytest ut/sdk/test_model_speedup.py --cov-config=.coveragerc --cov-append python -m pytest ut/compression/v1/test_model_speedup.py --cov-config=.coveragerc --cov-append
cp coverage.xml ../coverage/python.xml cp coverage.xml ../coverage/python.xml
displayName: Python unit test displayName: Python unit test
......
...@@ -15,8 +15,8 @@ from nni.algorithms.compression.pytorch.pruning import LevelPruner, SlimPruner, ...@@ -15,8 +15,8 @@ from nni.algorithms.compression.pytorch.pruning import LevelPruner, SlimPruner,
TaylorFOWeightFilterPruner, NetAdaptPruner, SimulatedAnnealingPruner, ADMMPruner, \ TaylorFOWeightFilterPruner, NetAdaptPruner, SimulatedAnnealingPruner, ADMMPruner, \
AutoCompressPruner, AMCPruner AutoCompressPruner, AMCPruner
sys.path.append(os.path.dirname(__file__)) sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from models.pytorch_models.mobilenet import MobileNet from sdk.models.pytorch_models.mobilenet import MobileNet
def validate_sparsity(wrapper, sparsity, bias=False): def validate_sparsity(wrapper, sparsity, bias=False):
masks = [wrapper.weight_mask] masks = [wrapper.weight_mask]
......
...@@ -13,8 +13,8 @@ from unittest import TestCase, main ...@@ -13,8 +13,8 @@ from unittest import TestCase, main
from nni.algorithms.compression.pytorch.pruning import TransformerHeadPruner from nni.algorithms.compression.pytorch.pruning import TransformerHeadPruner
sys.path.append(os.path.dirname(__file__)) sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from models.pytorch_models.transformer import TransformerEncoder from sdk.models.pytorch_models.transformer import TransformerEncoder
def validate_sparsity(wrapper, sparsity, bias=False): def validate_sparsity(wrapper, sparsity, bias=False):
......
...@@ -41,6 +41,7 @@ class TorchModel(torch.nn.Module): ...@@ -41,6 +41,7 @@ class TorchModel(torch.nn.Module):
def trainer(model, optimizer, criterion): def trainer(model, optimizer, criterion):
model.train() model.train()
for _ in range(10):
input = torch.rand(10, 1, 28, 28) input = torch.rand(10, 1, 28, 28)
label = torch.Tensor(list(range(10))).type(torch.LongTensor) label = torch.Tensor(list(range(10))).type(torch.LongTensor)
optimizer.zero_grad() optimizer.zero_grad()
...@@ -65,7 +66,7 @@ class IterativePrunerTestCase(unittest.TestCase): ...@@ -65,7 +66,7 @@ class IterativePrunerTestCase(unittest.TestCase):
def test_linear_pruner(self): def test_linear_pruner(self):
model = TorchModel() model = TorchModel()
config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}] config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
pruner = LinearPruner(model, config_list, 'level', 3, log_dir='../../logs') pruner = LinearPruner(model, config_list, 'level', 3, log_dir='../../../logs')
pruner.compress() pruner.compress()
_, pruned_model, masks, _, _ = pruner.get_best_result() _, pruned_model, masks, _, _ = pruner.get_best_result()
sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list) sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
...@@ -74,7 +75,7 @@ class IterativePrunerTestCase(unittest.TestCase): ...@@ -74,7 +75,7 @@ class IterativePrunerTestCase(unittest.TestCase):
def test_agp_pruner(self): def test_agp_pruner(self):
model = TorchModel() model = TorchModel()
config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}] config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
pruner = AGPPruner(model, config_list, 'level', 3, log_dir='../../logs') pruner = AGPPruner(model, config_list, 'level', 3, log_dir='../../../logs')
pruner.compress() pruner.compress()
_, pruned_model, masks, _, _ = pruner.get_best_result() _, pruned_model, masks, _, _ = pruner.get_best_result()
sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list) sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
...@@ -83,7 +84,7 @@ class IterativePrunerTestCase(unittest.TestCase): ...@@ -83,7 +84,7 @@ class IterativePrunerTestCase(unittest.TestCase):
def test_lottery_ticket_pruner(self): def test_lottery_ticket_pruner(self):
model = TorchModel() model = TorchModel()
config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}] config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
pruner = LotteryTicketPruner(model, config_list, 'level', 3, log_dir='../../logs') pruner = LotteryTicketPruner(model, config_list, 'level', 3, log_dir='../../../logs')
pruner.compress() pruner.compress()
_, pruned_model, masks, _, _ = pruner.get_best_result() _, pruned_model, masks, _, _ = pruner.get_best_result()
sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list) sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
...@@ -92,7 +93,7 @@ class IterativePrunerTestCase(unittest.TestCase): ...@@ -92,7 +93,7 @@ class IterativePrunerTestCase(unittest.TestCase):
def test_simulated_annealing_pruner(self): def test_simulated_annealing_pruner(self):
model = TorchModel() model = TorchModel()
config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}] config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
pruner = SimulatedAnnealingPruner(model, config_list, evaluator, start_temperature=40, log_dir='../../logs') pruner = SimulatedAnnealingPruner(model, config_list, evaluator, start_temperature=40, log_dir='../../../logs')
pruner.compress() pruner.compress()
_, pruned_model, masks, _, _ = pruner.get_best_result() _, pruned_model, masks, _, _ = pruner.get_best_result()
sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list) sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
...@@ -112,7 +113,7 @@ class IterativePrunerTestCase(unittest.TestCase): ...@@ -112,7 +113,7 @@ class IterativePrunerTestCase(unittest.TestCase):
'evaluator': evaluator, 'evaluator': evaluator,
'start_temperature': 40 'start_temperature': 40
} }
pruner = AutoCompressPruner(model, config_list, 10, admm_params, sa_params=sa_params, log_dir='../../logs') pruner = AutoCompressPruner(model, config_list, 10, admm_params, sa_params=sa_params, log_dir='../../../logs')
pruner.compress() pruner.compress()
_, pruned_model, masks, _, _ = pruner.get_best_result() _, pruned_model, masks, _, _ = pruner.get_best_result()
sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list) sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
......
...@@ -44,6 +44,7 @@ class TorchModel(torch.nn.Module): ...@@ -44,6 +44,7 @@ class TorchModel(torch.nn.Module):
def trainer(model, optimizer, criterion): def trainer(model, optimizer, criterion):
model.train() model.train()
for _ in range(10):
input = torch.rand(10, 1, 28, 28) input = torch.rand(10, 1, 28, 28)
label = torch.Tensor(list(range(10))).type(torch.LongTensor) label = torch.Tensor(list(range(10))).type(torch.LongTensor)
optimizer.zero_grad() optimizer.zero_grad()
......
...@@ -50,6 +50,7 @@ class TorchModel(torch.nn.Module): ...@@ -50,6 +50,7 @@ class TorchModel(torch.nn.Module):
def trainer(model, optimizer, criterion): def trainer(model, optimizer, criterion):
model.train() model.train()
for _ in range(10):
input = torch.rand(10, 1, 28, 28) input = torch.rand(10, 1, 28, 28)
label = torch.Tensor(list(range(10))).type(torch.LongTensor) label = torch.Tensor(list(range(10))).type(torch.LongTensor)
optimizer.zero_grad() optimizer.zero_grad()
......
...@@ -7,7 +7,7 @@ import unittest ...@@ -7,7 +7,7 @@ import unittest
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from nni.algorithms.compression.v2.pytorch.base import Task, TaskResult from nni.algorithms.compression.v2.pytorch.base import TaskResult
from nni.algorithms.compression.v2.pytorch.pruning.tools import ( from nni.algorithms.compression.v2.pytorch.pruning.tools import (
AGPTaskGenerator, AGPTaskGenerator,
LinearTaskGenerator, LinearTaskGenerator,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment