Unverified Commit ef4459e4 authored by J-shang, committed by GitHub

[Compression v2] Auto-Compress Pruner (#4280)

parent cb6c72ea
......@@ -26,6 +26,7 @@ and how to schedule sparsity in each iteration are implemented as iterative prun
* `AGP Pruner <#agp-pruner>`__
* `Lottery Ticket Pruner <#lottery-ticket-pruner>`__
* `Simulated Annealing Pruner <#simulated-annealing-pruner>`__
* `Auto Compress Pruner <#auto-compress-pruner>`__
Level Pruner
------------
......@@ -397,3 +398,45 @@ User configuration for Simulated Annealing Pruner
**PyTorch**
.. autoclass:: nni.algorithms.compression.v2.pytorch.pruning.SimulatedAnnealingPruner
Auto Compress Pruner
--------------------
For a total iteration number :math:`N`, AutoCompressPruner prunes the model that survived the previous iteration by a fixed sparsity ratio (e.g., :math:`1-{(1-0.8)}^{(1/N)}`) in each iteration, so that the overall sparsity (e.g., :math:`0.8`) is reached after :math:`N` iterations:
.. code-block:: bash
1. Generate a sparsity distribution using SimulatedAnnealingPruner.
2. Perform ADMM-based pruning to generate pruning result for the next iteration.
For more details, please refer to `AutoCompress: An Automatic DNN Structured Pruning Framework for Ultra-High Compression Rates <https://arxiv.org/abs/1907.03141>`__.
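
To see how the fixed per-iteration ratio compounds to the overall target, here is a minimal sketch (plain Python, not part of NNI) that prunes the same fraction of the surviving weights in every iteration:

.. code-block:: python

   # hypothetical illustration of the per-iteration ratio used above
   overall_sparsity = 0.8
   N = 10
   per_iteration_ratio = 1 - (1 - overall_sparsity) ** (1 / N)

   remaining = 1.0
   for _ in range(N):
       remaining *= 1 - per_iteration_ratio  # prune a fixed fraction of what survives

   print(1 - remaining)  # ~0.8, the overall target sparsity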
Usage
^^^^^^
.. code-block:: python
from nni.algorithms.compression.v2.pytorch.pruning import AutoCompressPruner
config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
admm_params = {
    'trainer': trainer,
    'optimizer': optimizer,
    'criterion': criterion,
    'iterations': 10,
    'training_epochs': 1
}
sa_params = {
    'evaluator': evaluator
}
pruner = AutoCompressPruner(model, config_list, 10, admm_params, sa_params, finetuner=finetuner)
pruner.compress()
_, model, masks, _, _ = pruner.get_best_result()
The full script can be found :githublink:`here <examples/model_compress/pruning/v2/auto_compress_pruner.py>`.
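
Judging from ``PruningScheduler.get_best_result`` in this change, the returned tuple unpacks as ``(task_id, model, masks, score, config_list)``, so the snippet above keeps the pruned model and its masks:

.. code-block:: python

   best_task_id, best_model, best_masks, best_score, best_config_list = pruner.get_best_result()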
User configuration for Auto Compress Pruner
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
**PyTorch**
.. autoclass:: nni.algorithms.compression.v2.pytorch.pruning.AutoCompressPruner
from tqdm import tqdm
import torch
from torchvision import datasets, transforms
from nni.algorithms.compression.v2.pytorch.pruning import AutoCompressPruner
from examples.model_compress.models.cifar10.vgg import VGG
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./data', train=True, transform=transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, 4),
        transforms.ToTensor(),
        normalize,
    ]), download=True),
    batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])),
    batch_size=128, shuffle=False)
criterion = torch.nn.CrossEntropyLoss()
epoch = 0
def trainer(model, optimizer, criterion):
    global epoch
    model.train()
    for data, target in tqdm(iterable=train_loader, desc='Total Epoch {}'.format(epoch)):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
    epoch = epoch + 1

def finetuner(model):
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    criterion = torch.nn.CrossEntropyLoss()
    trainer(model, optimizer, criterion)

def evaluator(model):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in tqdm(iterable=test_loader, desc='Test'):
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    acc = 100 * correct / len(test_loader.dataset)
    print('Accuracy: {}%\n'.format(acc))
    return acc
if __name__ == '__main__':
    model = VGG().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    criterion = torch.nn.CrossEntropyLoss()

    # pre-train the model
    for _ in range(10):
        trainer(model, optimizer, criterion)

    config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.8}]
    dummy_input = torch.rand(10, 3, 32, 32).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    admm_params = {
        'trainer': trainer,
        'optimizer': optimizer,
        'criterion': criterion,
        'iterations': 10,
        'training_epochs': 1
    }
    sa_params = {
        'evaluator': evaluator
    }
    pruner = AutoCompressPruner(model, config_list, 10, admm_params, sa_params, keep_intermediate_result=True, finetuner=finetuner)
    pruner.compress()
    _, model, masks, _, _ = pruner.get_best_result()
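
    # Note: `dummy_input` above is defined but never passed to the pruner. Based on
    # the AutoCompressPruner docstring, it is only needed when `speed_up=True`; a
    # variant call might look like this (an assumption, not part of the original example):
    # pruner = AutoCompressPruner(model, config_list, 10, admm_params, sa_params,
    #                             keep_intermediate_result=True, finetuner=finetuner,
    #                             speed_up=True, dummy_input=dummy_input)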
from .basic_pruner import *
from .basic_scheduler import PruningScheduler
from .iterative_pruner import *
from .auto_compress_pruner import AutoCompressPruner
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from pathlib import Path
from typing import Dict, List, Callable, Optional
from torch import Tensor
from torch.nn import Module
from .basic_pruner import ADMMPruner
from .iterative_pruner import IterativePruner, SimulatedAnnealingPruner
from .tools import LotteryTicketTaskGenerator
class AutoCompressTaskGenerator(LotteryTicketTaskGenerator):
    def __init__(self, total_iteration: int, origin_model: Module, origin_config_list: List[Dict],
                 origin_masks: Dict[str, Dict[str, Tensor]] = {}, sa_params: Dict = {}, log_dir: str = '.',
                 keep_intermediate_result: bool = False):
        self.iterative_pruner = SimulatedAnnealingPruner(model=None,
                                                         config_list=None,
                                                         log_dir=Path(log_dir, 'SA'),
                                                         **sa_params)
        super().__init__(total_iteration=total_iteration,
                         origin_model=origin_model,
                         origin_config_list=origin_config_list,
                         origin_masks=origin_masks,
                         log_dir=log_dir,
                         keep_intermediate_result=keep_intermediate_result)

    def _iterative_pruner_reset(self, model: Module, config_list: List[Dict] = [], masks: Dict[str, Dict[str, Tensor]] = {}):
        self.iterative_pruner.task_generator._log_dir = Path(self._log_dir_root, 'SA')
        self.iterative_pruner.reset(model, config_list=config_list, masks=masks)

    def allocate_sparsity(self, new_config_list: List[Dict], model: Module, masks: Dict[str, Dict[str, Tensor]]):
        self._iterative_pruner_reset(model, new_config_list, masks)
        self.iterative_pruner.compress()
        _, _, _, _, config_list = self.iterative_pruner.get_best_result()
        return config_list
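
# Each AutoCompress iteration thus runs two stages: the task generator above uses
# SimulatedAnnealingPruner to search a per-layer sparsity allocation, then the
# scheduler applies ADMM-based pruning (ADMMPruner) under that allocation,
# matching the two-step procedure described in the documentation section above.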
class AutoCompressPruner(IterativePruner):
    """
    Parameters
    ----------
    model : Module
        The origin unwrapped pytorch model to be pruned.
    config_list : List[Dict]
        The origin config list provided by the user.
    total_iteration : int
        The total iteration number.
    evaluator : Callable[[Module], float]
        Evaluate the pruned model and give a score.
    admm_params : Dict
        The parameters passed to the ADMMPruner.

        - trainer : Callable[[Module, Optimizer, Callable], None].
            A callable function used to train the model or just run inference. Takes model, optimizer, criterion as input.
            The model will be trained or inferenced for `training_epochs` epochs.
        - optimizer : torch.optim.Optimizer.
            The optimizer instance used in the trainer. Note that this optimizer might be patched during data collection,
            so do not use this optimizer in other places.
        - criterion : Callable[[Tensor, Tensor], Tensor].
            The criterion function used in the trainer. Takes model output and target value as input, and returns the loss.
        - iterations : int.
            The total iteration number of the ADMM pruning algorithm.
        - training_epochs : int.
            The epoch number for training the model in each iteration.
    sa_params : Dict
        The parameters passed to the SimulatedAnnealingPruner.

        - evaluator : Callable[[Module], float]. Required.
            Evaluate the pruned model and give a score.
        - start_temperature : float. Default: `100`.
            Start temperature of the simulated annealing process.
        - stop_temperature : float. Default: `20`.
            Stop temperature of the simulated annealing process.
        - cool_down_rate : float. Default: `0.9`.
            Cool down rate of the temperature.
        - perturbation_magnitude : float. Default: `0.35`.
            Initial perturbation magnitude to the sparsities. The magnitude decreases with the current temperature.
        - pruning_algorithm : str. Default: `'level'`.
            Supported pruning algorithms ['level', 'l1', 'l2', 'fpgm', 'slim', 'apoz', 'mean_activation', 'taylorfo', 'admm'].
        - pruning_params : Dict. Default: `{}`.
            If the chosen pruning_algorithm has extra parameters, put them as a dict to pass in.
    log_dir : str
        The log directory used to save the result; you can find the best result under this folder.
    keep_intermediate_result : bool
        Whether to keep the intermediate results, including the intermediate model and masks during each iteration.
    finetuner : Optional[Callable[[Module], None]]
        The finetuner handles all finetune logic, taking a pytorch module as input.
        It will be called at the end of each iteration, usually for neutralizing the accuracy loss brought by the pruning in that iteration.
    speed_up : bool
        If set True, speed up the model at the end of each iteration to make the pruned model compact.
    dummy_input : Optional[torch.Tensor]
        If `speed_up` is True, `dummy_input` is required for tracing the model in speed up.
    """

    def __init__(self, model: Module, config_list: List[Dict], total_iteration: int, admm_params: Dict,
                 sa_params: Dict, log_dir: str = '.', keep_intermediate_result: bool = False,
                 finetuner: Optional[Callable[[Module], None]] = None, speed_up: bool = False,
                 dummy_input: Optional[Tensor] = None, evaluator: Callable[[Module], float] = None):
        task_generator = AutoCompressTaskGenerator(total_iteration=total_iteration,
                                                   origin_model=model,
                                                   origin_config_list=config_list,
                                                   sa_params=sa_params,
                                                   log_dir=log_dir,
                                                   keep_intermediate_result=keep_intermediate_result)
        pruner = ADMMPruner(None, None, **admm_params)
        super().__init__(pruner, task_generator, finetuner=finetuner, speed_up=speed_up, dummy_input=dummy_input,
                         evaluator=evaluator, reset_weight=False)
......@@ -720,12 +720,15 @@ class ADMMPruner(BasicPruner):
    def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
                 optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor], iterations: int, training_epochs: int):
        self.trainer = trainer
        # TODO: handling the optimizer here will cause additional memory use; needs improvement, also in WeightTrainerBasedDataCollector
        self.optimizer = optimizer
        self.criterion = criterion
        self.iterations = iterations
        self.training_epochs = training_epochs
        super().__init__(model, config_list)

    def reset(self, model: Optional[Module], config_list: Optional[List[Dict]]):
        super().reset(model, config_list)

        self.Z = {name: wrapper.module.weight.data.clone().detach() for name, wrapper in self.get_modules_wrapper().items()}
        self.U = {name: torch.zeros_like(z).to(z.device) for name, z in self.Z.items()}
......@@ -777,6 +780,10 @@ class ADMMPruner(BasicPruner):
                self.Z[name] = self.Z[name].mul(mask['weight'])
                self.U[name] = self.U[name] + data[name] - self.Z[name]

        # release the ADMM auxiliary variables before generating the final masks
        self.Z = None
        self.U = None
        torch.cuda.empty_cache()

        metrics = self.metrics_calculator.calculate_metrics(data)
        masks = self.sparsity_allocator.generate_sparsity(metrics)
......
......@@ -4,6 +4,7 @@
from copy import deepcopy
from typing import Dict, List, Tuple, Callable, Optional
import torch
from torch import Tensor
from torch.nn import Module
......@@ -24,10 +25,11 @@ class PruningScheduler(BasePruningScheduler):
        Used to generate the task for each iteration.
    finetuner
        The finetuner handles all finetune logic, taking a pytorch module as input.
        It will be called at the end of each iteration if reset_weight is False, otherwise at the beginning of each iteration.
    speed_up
        If set True, speed up the model at the end of each iteration to make the pruned model compact.
    dummy_input
        If `speed_up` is True, `dummy_input` is required for tracing the model in speed up.
    evaluator
        Evaluate the pruned model and give a score.
        If evaluator is None, the best result refers to the latest result.
......@@ -45,6 +47,9 @@ class PruningScheduler(BasePruningScheduler):
        self.evaluator = evaluator
        self.reset_weight = reset_weight

    def reset(self, model: Module, config_list: List[Dict], masks: Dict[str, Dict[str, Tensor]] = {}):
        self.task_generator.reset(model, config_list, masks)

    def generate_task(self) -> Optional[Task]:
        return self.task_generator.next()
......@@ -144,9 +149,11 @@ class PruningScheduler(BasePruningScheduler):
    def pruning_one_step(self, task: Task) -> TaskResult:
        if self.reset_weight:
            result = self.pruning_one_step_reset_weight(task)
        else:
            result = self.pruning_one_step_normal(task)
        torch.cuda.empty_cache()
        return result

    def get_best_result(self) -> Optional[Tuple[int, Module, Dict[str, Dict[str, Tensor]], float, List[Dict]]]:
        return self.task_generator.get_best_result()
......@@ -19,7 +19,7 @@ from .basic_pruner import (
    ADMMPruner
)
from .basic_scheduler import PruningScheduler
from .tools import (
    LinearTaskGenerator,
    AGPTaskGenerator,
    LotteryTicketTaskGenerator,
......@@ -74,8 +74,7 @@ class LinearPruner(IterativePruner):
    model : Module
        The origin unwrapped pytorch model to be pruned.
    config_list : List[Dict]
        The origin config list provided by the user.
    pruning_algorithm : str
        Supported pruning algorithms ['level', 'l1', 'l2', 'fpgm', 'slim', 'apoz', 'mean_activation', 'taylorfo', 'admm'].
        This iterative pruner will use the chosen corresponding pruner to prune the model in each iteration.
......@@ -86,22 +85,23 @@ class LinearPruner(IterativePruner):
    keep_intermediate_result : bool
        Whether to keep the intermediate results, including the intermediate model and masks during each iteration.
    finetuner : Optional[Callable[[Module], None]]
        The finetuner handles all finetune logic, taking a pytorch module as input.
        It will be called at the end of each iteration, usually for neutralizing the accuracy loss brought by the pruning in that iteration.
    speed_up : bool
        If set True, speed up the model at the end of each iteration to make the pruned model compact.
    dummy_input : Optional[torch.Tensor]
        If `speed_up` is True, `dummy_input` is required for tracing the model in speed up.
    evaluator : Optional[Callable[[Module], float]]
        Evaluate the pruned model and give a score.
        If evaluator is None, the best result refers to the latest result.
    pruning_params : Dict
        If the chosen pruning_algorithm has extra parameters, put them as a dict to pass in.
    """

    def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
                 total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
                 finetuner: Optional[Callable[[Module], None]] = None, speed_up: bool = False, dummy_input: Optional[Tensor] = None,
                 evaluator: Optional[Callable[[Module], float]] = None, pruning_params: Dict = {}):
        task_generator = LinearTaskGenerator(total_iteration=total_iteration,
                                             origin_model=model,
                                             origin_config_list=config_list,
......@@ -119,8 +119,7 @@ class AGPPruner(IterativePruner):
    model : Module
        The origin unwrapped pytorch model to be pruned.
    config_list : List[Dict]
        The origin config list provided by the user.
    pruning_algorithm : str
        Supported pruning algorithms ['level', 'l1', 'l2', 'fpgm', 'slim', 'apoz', 'mean_activation', 'taylorfo', 'admm'].
        This iterative pruner will use the chosen corresponding pruner to prune the model in each iteration.
......@@ -131,22 +130,23 @@ class AGPPruner(IterativePruner):
    keep_intermediate_result : bool
        Whether to keep the intermediate results, including the intermediate model and masks during each iteration.
    finetuner : Optional[Callable[[Module], None]]
        The finetuner handles all finetune logic, taking a pytorch module as input.
        It will be called at the end of each iteration, usually for neutralizing the accuracy loss brought by the pruning in that iteration.
    speed_up : bool
        If set True, speed up the model at the end of each iteration to make the pruned model compact.
    dummy_input : Optional[torch.Tensor]
        If `speed_up` is True, `dummy_input` is required for tracing the model in speed up.
    evaluator : Optional[Callable[[Module], float]]
        Evaluate the pruned model and give a score.
        If evaluator is None, the best result refers to the latest result.
    pruning_params : Dict
        If the chosen pruning_algorithm has extra parameters, put them as a dict to pass in.
    """

    def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
                 total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
                 finetuner: Optional[Callable[[Module], None]] = None, speed_up: bool = False, dummy_input: Optional[Tensor] = None,
                 evaluator: Optional[Callable[[Module], float]] = None, pruning_params: Dict = {}):
        task_generator = AGPTaskGenerator(total_iteration=total_iteration,
                                          origin_model=model,
                                          origin_config_list=config_list,
......@@ -164,8 +164,7 @@ class LotteryTicketPruner(IterativePruner):
    model : Module
        The origin unwrapped pytorch model to be pruned.
    config_list : List[Dict]
        The origin config list provided by the user.
    pruning_algorithm : str
        Supported pruning algorithms ['level', 'l1', 'l2', 'fpgm', 'slim', 'apoz', 'mean_activation', 'taylorfo', 'admm'].
        This iterative pruner will use the chosen corresponding pruner to prune the model in each iteration.
......@@ -176,25 +175,26 @@ class LotteryTicketPruner(IterativePruner):
    keep_intermediate_result : bool
        Whether to keep the intermediate results, including the intermediate model and masks during each iteration.
    finetuner : Optional[Callable[[Module], None]]
        The finetuner handles all finetune logic, taking a pytorch module as input.
        It will be called at the end of each iteration if reset_weight is False, otherwise at the beginning of each iteration.
    speed_up : bool
        If set True, speed up the model at the end of each iteration to make the pruned model compact.
    dummy_input : Optional[torch.Tensor]
        If `speed_up` is True, `dummy_input` is required for tracing the model in speed up.
    evaluator : Optional[Callable[[Module], float]]
        Evaluate the pruned model and give a score.
        If evaluator is None, the best result refers to the latest result.
    reset_weight : bool
        If set True, the model weight will be reset to the original model weight at the end of each iteration step.
    pruning_params : Dict
        If the chosen pruning_algorithm has extra parameters, put them as a dict to pass in.
    """

    def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
                 total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
                 finetuner: Optional[Callable[[Module], None]] = None, speed_up: bool = False, dummy_input: Optional[Tensor] = None,
                 evaluator: Optional[Callable[[Module], float]] = None, reset_weight: bool = True,
                 pruning_params: Dict = {}):
        task_generator = LotteryTicketTaskGenerator(total_iteration=total_iteration,
                                                    origin_model=model,
                                                    origin_config_list=config_list,
......@@ -212,11 +212,7 @@ class SimulatedAnnealingPruner(IterativePruner):
    model : Module
        The origin unwrapped pytorch model to be pruned.
    config_list : List[Dict]
        The origin config list provided by the user.
    evaluator : Callable[[Module], float]
        Evaluate the pruned model and give a score.
    start_temperature : float
......@@ -227,6 +223,11 @@ class SimulatedAnnealingPruner(IterativePruner):
        Cool down rate of the temperature.
    perturbation_magnitude : float
        Initial perturbation magnitude to the sparsities. The magnitude decreases with the current temperature.
    pruning_algorithm : str
        Supported pruning algorithms ['level', 'l1', 'l2', 'fpgm', 'slim', 'apoz', 'mean_activation', 'taylorfo', 'admm'].
        This iterative pruner will use the chosen corresponding pruner to prune the model in each iteration.
    pruning_params : Dict
        If the chosen pruning_algorithm has extra parameters, put them as a dict to pass in.
    log_dir : str
        The log directory used to save the result; you can find the best result under this folder.
    keep_intermediate_result : bool
......@@ -234,18 +235,15 @@ class SimulatedAnnealingPruner(IterativePruner):
    finetuner : Optional[Callable[[Module], None]]
        The finetuner handles all finetune logic, taking a pytorch module as input; it will be called in each iteration.
    speed_up : bool
        If set True, speed up the model at the end of each iteration to make the pruned model compact.
    dummy_input : Optional[torch.Tensor]
        If `speed_up` is True, `dummy_input` is required for tracing the model in speed up.
    """

    def __init__(self, model: Module, config_list: List[Dict], evaluator: Callable[[Module], float], start_temperature: float = 100,
                 stop_temperature: float = 20, cool_down_rate: float = 0.9, perturbation_magnitude: float = 0.35,
                 pruning_algorithm: str = 'level', pruning_params: Dict = {}, log_dir: str = '.', keep_intermediate_result: bool = False,
                 finetuner: Optional[Callable[[Module], None]] = None, speed_up: bool = False, dummy_input: Optional[Tensor] = None):
        task_generator = SimulatedAnnealingTaskGenerator(origin_model=model,
                                                         origin_config_list=config_list,
                                                         start_temperature=start_temperature,
......
......@@ -450,10 +450,7 @@ class SparsityAllocator:
class TaskGenerator:
    """
    This class is used to generate the config list for the pruner in each iteration.

    Parameters
    ----------
    origin_model
......@@ -468,11 +465,20 @@ class TaskGenerator:
    keep_intermediate_result
        Whether to keep the intermediate results, including the intermediate model and masks during each iteration.
    """

    def __init__(self, origin_model: Optional[Module], origin_masks: Optional[Dict[str, Dict[str, Tensor]]] = {},
                 origin_config_list: Optional[List[Dict]] = [], log_dir: str = '.', keep_intermediate_result: bool = False):
        self._log_dir = log_dir
        self._keep_intermediate_result = keep_intermediate_result

        if origin_model is not None and origin_config_list is not None and origin_masks is not None:
            self.reset(origin_model, origin_config_list, origin_masks)

    def reset(self, model: Module, config_list: List[Dict] = [], masks: Dict[str, Dict[str, Tensor]] = {}):
        assert isinstance(model, Module), 'Only support pytorch module.'

        self._log_dir_root = Path(self._log_dir, datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f')).absolute()
        self._log_dir_root.mkdir(parents=True, exist_ok=True)

        self._intermediate_result_dir = Path(self._log_dir_root, 'intermediate_result')
        self._intermediate_result_dir.mkdir(parents=True, exist_ok=True)
......@@ -480,7 +486,7 @@ class TaskGenerator:
        self._origin_model_path = Path(self._log_dir_root, 'origin', 'model.pth')
        self._origin_masks_path = Path(self._log_dir_root, 'origin', 'masks.pth')
        self._origin_config_list_path = Path(self._log_dir_root, 'origin', 'config_list.json')
        self._save_data('origin', model, masks, config_list)

        self._task_id_candidate = 0
        self._tasks: Dict[int, Task] = {}
......@@ -43,13 +43,15 @@ class FunctionBasedTaskGenerator(TaskGenerator):
        keep_intermediate_result
            Whether to keep the intermediate results, including the intermediate model and masks during each iteration.
        """
        self.current_iteration = 0
        self.total_iteration = total_iteration
        super().__init__(origin_model, origin_config_list=origin_config_list, origin_masks=origin_masks,
                         log_dir=log_dir, keep_intermediate_result=keep_intermediate_result)

    def reset(self, model: Module, config_list: List[Dict] = [], masks: Dict[str, Dict[str, Tensor]] = {}):
        self.current_iteration = 0
        self.target_sparsity = config_list_canonical(model, config_list)
        super().reset(model, config_list=config_list, masks=masks)

    def init_pending_tasks(self) -> List[Task]:
        origin_model = torch.load(self._origin_model_path)
        origin_masks = torch.load(self._origin_masks_path)
......@@ -81,6 +83,7 @@ class FunctionBasedTaskGenerator(TaskGenerator):
        task_id = self._task_id_candidate
        new_config_list = self.generate_config_list(self.target_sparsity, self.current_iteration, compact2origin_sparsity)
        new_config_list = self.allocate_sparsity(new_config_list, compact_model, compact_model_masks)

        config_list_path = Path(self._intermediate_result_dir, '{}_config_list.json'.format(task_id))
        with Path(config_list_path).open('w') as f:
......@@ -97,6 +100,9 @@ class FunctionBasedTaskGenerator(TaskGenerator):
    def generate_config_list(self, target_sparsity: List[Dict], iteration: int, compact2origin_sparsity: List[Dict]) -> List[Dict]:
        raise NotImplementedError()

    def allocate_sparsity(self, new_config_list: List[Dict], model: Module, masks: Dict[str, Dict[str, Tensor]]):
        return new_config_list


class AGPTaskGenerator(FunctionBasedTaskGenerator):
    def generate_config_list(self, target_sparsity: List[Dict], iteration: int, compact2origin_sparsity: List[Dict]) -> List[Dict]:
......@@ -123,11 +129,10 @@ class LinearTaskGenerator(FunctionBasedTaskGenerator):
class LotteryTicketTaskGenerator(FunctionBasedTaskGenerator):
    def reset(self, model: Module, config_list: List[Dict] = [], masks: Dict[str, Dict[str, Tensor]] = {}):
        self.current_iteration = 1
        self.target_sparsity = config_list_canonical(model, config_list)
        super(FunctionBasedTaskGenerator, self).reset(model, config_list=config_list, masks=masks)

    def generate_config_list(self, target_sparsity: List[Dict], iteration: int, compact2origin_sparsity: List[Dict]) -> List[Dict]:
        config_list = []
......@@ -172,21 +177,25 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
            Whether to keep the intermediate results, including the intermediate model and masks during each iteration.
        """
        self.start_temperature = start_temperature
        self.stop_temperature = stop_temperature
        self.cool_down_rate = cool_down_rate
        self.perturbation_magnitude = perturbation_magnitude

        super().__init__(origin_model, origin_masks=origin_masks, origin_config_list=origin_config_list,
                         log_dir=log_dir, keep_intermediate_result=keep_intermediate_result)

    def reset(self, model: Module, config_list: List[Dict] = [], masks: Dict[str, Dict[str, Tensor]] = {}):
        self.current_temperature = self.start_temperature

        self.weights_numel, self.masked_rate = get_model_weights_numel(model, config_list, masks)
        self.target_sparsity_list = config_list_canonical(model, config_list)
        self._adjust_target_sparsity()

        self._temp_config_list = None
        self._current_sparsity_list = None
        self._current_score = None

        super().reset(model, config_list=config_list, masks=masks)
    def _adjust_target_sparsity(self):
        """
......@@ -199,9 +208,10 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
                pruned_weight_numel = 0
                for name in op_names:
                    remaining_weight_numel += self.weights_numel[name]
                    if name in self.masked_rate and self.masked_rate[name] != 0:
                        pruned_weight_numel += 1 / (1 / self.masked_rate[name] - 1) * self.weights_numel[name]
                # rescale the target sparsity to be relative to the weights that are still unmasked
                total_mask_rate = pruned_weight_numel / (pruned_weight_numel + remaining_weight_numel)
                config['total_sparsity'] = max(0, (sparsity - total_mask_rate) / (1 - total_mask_rate))

    def _init_temp_config_list(self):
        self._temp_config_list = []
......
......@@ -11,7 +11,8 @@ from nni.algorithms.compression.v2.pytorch.pruning import (
    LinearPruner,
    AGPPruner,
    LotteryTicketPruner,
    SimulatedAnnealingPruner,
    AutoCompressPruner
)
from nni.algorithms.compression.v2.pytorch.utils import compute_sparsity_mask2compact
......@@ -38,6 +39,24 @@ class TorchModel(torch.nn.Module):
        return F.log_softmax(x, dim=1)


def trainer(model, optimizer, criterion):
    model.train()
    input = torch.rand(10, 1, 28, 28)
    label = torch.Tensor(list(range(10))).type(torch.LongTensor)
    optimizer.zero_grad()
    output = model(input)
    loss = criterion(output, label)
    loss.backward()
    optimizer.step()


def get_optimizer(model):
    return torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)


criterion = torch.nn.CrossEntropyLoss()


def evaluator(model):
    return random.random()
......@@ -50,7 +69,7 @@ class IterativePrunerTestCase(unittest.TestCase):
        pruner.compress()
        _, pruned_model, masks, _, _ = pruner.get_best_result()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        assert 0.78 < sparsity_list[0]['total_sparsity'] < 0.82

    def test_agp_pruner(self):
        model = TorchModel()
......@@ -59,7 +78,7 @@ class IterativePrunerTestCase(unittest.TestCase):
        pruner.compress()
        _, pruned_model, masks, _, _ = pruner.get_best_result()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        assert 0.78 < sparsity_list[0]['total_sparsity'] < 0.82

    def test_lottery_ticket_pruner(self):
        model = TorchModel()
......@@ -68,16 +87,37 @@ class IterativePrunerTestCase(unittest.TestCase):
        pruner.compress()
        _, pruned_model, masks, _, _ = pruner.get_best_result()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        assert 0.78 < sparsity_list[0]['total_sparsity'] < 0.82

    def test_simulated_annealing_pruner(self):
        model = TorchModel()
        config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
        pruner = SimulatedAnnealingPruner(model, config_list, evaluator, start_temperature=40, log_dir='../../logs')
        pruner.compress()
        _, pruned_model, masks, _, _ = pruner.get_best_result()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        assert 0.78 < sparsity_list[0]['total_sparsity'] < 0.82

    def test_auto_compress_pruner(self):
        model = TorchModel()
        config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
        admm_params = {
            'trainer': trainer,
            'optimizer': get_optimizer(model),
            'criterion': criterion,
            'iterations': 10,
            'training_epochs': 1
        }
        sa_params = {
            'evaluator': evaluator,
            'start_temperature': 40
        }
        pruner = AutoCompressPruner(model, config_list, 10, admm_params, sa_params=sa_params, log_dir='../../logs')
        pruner.compress()
        _, pruned_model, masks, _, _ = pruner.get_best_result()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        print(sparsity_list)
        assert 0.78 < sparsity_list[0]['total_sparsity'] < 0.82
if __name__ == '__main__':
    unittest.main()
......@@ -57,7 +57,7 @@ def run_task_generator(task_generator_type):
    elif task_generator_type == 'linear':
        assert count == 6
    elif task_generator_type == 'lottery_ticket':
        assert count == 5
    elif task_generator_type == 'simulated_annealing':
        assert count == 17
......