Unverified Commit e98ebcf0 authored by J-shang, committed by GitHub

[Model Compression] Pruning Scheduler (#4089)

parent 04f439a0
from .compressor import Compressor, LayerInfo
from .pruner import Pruner, PrunerModuleWrapper
from .scheduler import BasePruningScheduler, Task, TaskResult
...@@ -84,6 +84,17 @@ class Compressor:
self._wrap_model()
def clear_model_references(self):
"""
Clear all references to the model held by this compressor, to free up memory.
Call `reset()` before using this compressor again.
"""
self._unwrap_model()
self.bound_model = None
self.config_list = None
self.modules_wrapper = None
self._modules_to_compress = None
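For context, a minimal reuse sketch (hypothetical `pruner`, `new_model`, and `new_config_list`), showing the clear-then-reset pattern this method expects:

masked_model, masks = pruner.compress()
pruner.clear_model_references()           # drop references so the old model can be garbage collected
pruner.reset(new_model, new_config_list)  # re-bind before calling compressor functions again
masked_model, masks = pruner.compress()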
def _detect_modules_to_compress(self) -> List[Tuple[LayerInfo, Dict]]:
"""
Detect all modules that should be compressed, and save the result in `self._modules_to_compress`.
......
...@@ -87,14 +87,17 @@ class Pruner(Compressor):
Parameters
----------
masks
The masks dict with format {'op_name': {'weight_mask': mask, 'bias_mask': mask}}.
The masks dict with format {'op_name': {'weight': mask, 'bias': mask}}.
"""
wrappers = self.get_modules_wrapper()
for name, layer_mask in masks.items():
assert name in wrappers, '{} is not in wrappers of this pruner, can not apply the mask.'.format(name)
for mask_type, mask in layer_mask.items():
assert hasattr(wrappers[name], mask_type), 'there is no attribute {} in wrapper'.format(mask_type)
setattr(wrappers[name], mask_type, mask)
if layer_mask.get('weight') is not None:
assert hasattr(wrappers[name], 'weight_mask'), 'There is no attribute weight_mask in wrapper.'
setattr(wrappers[name], 'weight_mask', layer_mask.get('weight'))
if layer_mask.get('bias') is not None:
assert hasattr(wrappers[name], 'bias_mask'), 'There is no attribute bias_mask in wrapper.'
setattr(wrappers[name], 'bias_mask', layer_mask.get('bias'))
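For illustration only (not part of this commit), a masks dict in the new `weight`/`bias` format that `load_masks` accepts; the layer name and shapes are made up:

import torch

masks = {
    'fc1': {
        'weight': torch.ones(128, 64),   # 1 keeps an element, 0 prunes it
        'bias': torch.ones(128),
    },
}
pruner.load_masks(masks)                 # `pruner` is any Pruner whose wrappers contain 'fc1'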
def compress(self) -> Tuple[Module, Dict[str, Dict[str, Tensor]]]:
"""
...@@ -126,27 +129,21 @@ class Pruner(Compressor):
index = torch.nonzero(weight_mask.abs().sum(sum_idx) != 0, as_tuple=False).tolist()
_logger.info(f'simulated prune {wrapper.name} remain/total: {len(index)}/{weight_mask.size(dim)}')
def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=None, device=None):
def export_model(self, model_path: str, mask_path: Optional[str] = None):
"""
Export pruned model weights, masks and onnx model(optional)
Parameters
----------
model_path
Path to save pruned model state_dict.
Path to save pruned model state_dict. The weights and biases have already been multiplied by the masks.
mask_path
(optional) path to save mask dict.
Path to save mask dict.
onnx_path
(optional) path to save onnx model.
input_shape
Input shape to onnx model.
device
Device of the model, used to place the dummy input tensor for exporting onnx file.
The tensor is placed on cpu if ```device``` is None.
""" """
assert model_path is not None, 'model_path must be specified' assert self.bound_model is not None, 'The bound model reference has been cleared.'
assert model_path is not None, 'model_path must be specified.'
mask_dict = {} mask_dict = {}
self._unwrap_model() # used for generating correct state_dict name without wrapper state self._unwrap_model()
for name, wrapper in self.get_modules_wrapper().items(): for name, wrapper in self.get_modules_wrapper().items():
weight_mask = wrapper.weight_mask weight_mask = wrapper.weight_mask
...@@ -159,20 +156,13 @@ class Pruner(Compressor): ...@@ -159,20 +156,13 @@ class Pruner(Compressor):
if bias_mask is not None: if bias_mask is not None:
wrapper.module.bias.data = wrapper.module.bias.data.mul(bias_mask) wrapper.module.bias.data = wrapper.module.bias.data.mul(bias_mask)
# save mask to dict # save mask to dict
mask_dict[name] = {"weight_mask": weight_mask, "bias_mask": bias_mask} mask_dict[name] = {"weight": weight_mask, "bias": bias_mask}
torch.save(self.bound_model.state_dict(), model_path) torch.save(self.bound_model.state_dict(), model_path)
_logger.info('Model state_dict saved to %s', model_path) _logger.info('Model state_dict saved to %s', model_path)
if mask_path is not None: if mask_path is not None:
torch.save(mask_dict, mask_path) torch.save(mask_dict, mask_path)
_logger.info('Mask dict saved to %s', mask_path) _logger.info('Mask dict saved to %s', mask_path)
if onnx_path is not None:
assert input_shape is not None, 'input_shape must be specified to export onnx model'
# input info needed
if device is None:
device = torch.device('cpu')
input_data = torch.Tensor(*input_shape)
torch.onnx.export(self.bound_model, input_data.to(device), onnx_path)
_logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path)
self._wrap_model()
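A small usage sketch of the narrowed `export_model` signature (paths are placeholders). ONNX export is no longer handled here; if needed, run `torch.onnx.export` on the exported model separately:

pruner.export_model(model_path='pruned_model.pth', mask_path='pruned_masks.pth')
# The saved state_dict already has the masks multiplied into the weights and biases.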
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import gc
import logging
import os
from pathlib import Path
from typing import List, Dict, Tuple, Literal, Optional
import json_tricks
import torch
from torch.nn import Module
from torch.tensor import Tensor
_logger = logging.getLogger(__name__)
class Task:
# NOTE: If we want to support multi-thread, this part needs to be refactored; maybe use a file and a lock to sync.
_reference_counter = {}
def __init__(self, task_id: int, model_path: str, masks_path: str, config_list_path: str) -> None:
"""
Parameters
----------
task_id
The unique id of task.
model_path
The path of the unwrapped pytorch model that will be pruned in this task.
masks_path
The path of the masks applied to the model before pruning.
config_list_path
The path of the config list used in this task.
"""
self.task_id = task_id
self.model_path = model_path
self.masks_path = masks_path
self.config_list_path = config_list_path
self.status: Literal['Pending', 'Running', 'Finished'] = 'Pending'
self.score: Optional[float] = None
self.state = {}
for ref in self.referenced_paths():
self._reference_counter.setdefault(ref, 0)
self._reference_counter[ref] += 1
self._cleaned = False
def to_dict(self) -> Dict:
return {
'task_id': self.task_id,
'model_path': str(self.model_path),
'masks_path': str(self.masks_path),
'config_list_path': str(self.config_list_path),
'status': self.status,
'score': self.score,
'state': self.state
}
def load_data(self) -> Tuple[Module, Dict[str, Dict[str, Tensor]], List[Dict]]:
"""
Returns
-------
Tuple[Module, Dict[str, Dict[str, Tensor]], List[Dict]]
Return the model to be pruned in this task, the masks on the model before pruning,
and the config list used in this task.
"""
model = torch.load(self.model_path)
masks = torch.load(self.masks_path)
with Path(self.config_list_path).open('r') as f:
config_list = json_tricks.load(f)
return model, masks, config_list
def referenced_paths(self) -> List[str]:
"""
Return the list of file paths that are reference-counted in this task.
"""
return [self.model_path, self.masks_path, self.config_list_path]
def clean_up(self):
"""
Decrease the reference counter of each referenced file path by 1. If a counter reaches 0, delete the corresponding file.
"""
if not self._cleaned:
for ref in self.referenced_paths():
self._reference_counter[ref] -= 1
if self._reference_counter[ref] <= 0:
os.remove(ref)
if self._reference_counter[ref] < 0:
_logger.warning('Reference counter error, the number of %s is %d',
ref, self._reference_counter[ref])
self._cleaned = True
else:
_logger.warning('Already cleaned up task %d', self.task_id)
class TaskResult:
def __init__(self, task_id: int, compact_model: Module, compact_model_masks: Dict[str, Dict[str, Tensor]],
pruner_generated_masks: Dict[str, Dict[str, Tensor]], score: Optional[float]) -> None:
"""
Parameters
----------
task_id
The unique id of task.
compact_model
The unwrapped compact pytorch model after pruning. If the compact model has been sped up during the pruning process,
it will have a smaller structure compared with the model before pruning.
If the compact model has not been sped up, it will have the same structure as the model before pruning.
compact_model_masks
The masks on the compact model. If the compact model has been sped up during the pruning process,
`compact_model_masks` is always an empty dict. If the compact model has not been sped up,
`compact_model_masks` is the same as `pruner_generated_masks`.
pruner_generated_masks
The masks that can be applied to the model before pruning. It is always the output of `pruner.compress()`.
TODO: If the compact model has been sped up, the auto-inferred masks may also be needed.
score
The score of the pruning effect, e.g., the accuracy or latency after pruning.
"""
self.task_id = task_id
self.compact_model = compact_model
self.compact_model_masks = compact_model_masks
self.pruner_generated_masks = pruner_generated_masks
self.score = score
class BasePruningScheduler:
def generate_task(self) -> Optional[Task]:
"""
Returns
-------
Optional[Task]
Return the next pruning task.
"""
raise NotImplementedError()
def record_task_result(self, task_result: TaskResult):
"""
Parameters
----------
task_result
The result of the task
"""
raise NotImplementedError()
def pruning_one_step(self, task: Task) -> TaskResult:
"""
Prune the model defined in the task.
Parameters
----------
task
The pruning task in this step.
Returns
-------
TaskResult
Return the result of the task in this step.
"""
raise NotImplementedError()
def get_best_result(self) -> Tuple[int, Module, Dict[str, Dict[str, Tensor]], float, List[Dict]]:
"""
Returns
-------
Tuple[int, Module, Dict[str, Dict[str, Tensor]], float, List[Dict]]
Return the task result that has the best performance,
including the task id, the compact model, the masks on the compact model, the score, and the config list used in this task.
"""
raise NotImplementedError()
def compress(self):
"""
The main loop of the pruning scheduler.
"""
task = self.generate_task()
while task is not None:
task_result = self.pruning_one_step(task)
self.record_task_result(task_result)
del task_result
gc.collect()
task = self.generate_task()
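To make the loop above concrete, a toy scheduler only needs the three hooks that `compress()` calls; a sketch, not part of this commit:

class OneShotScheduler(BasePruningScheduler):
    # Hands out a single pre-built Task, records its result, then stops.
    def __init__(self, task: Task):
        self._task: Optional[Task] = task
        self.result: Optional[TaskResult] = None

    def generate_task(self) -> Optional[Task]:
        task, self._task = self._task, None   # returning None ends the compress() loop
        return task

    def pruning_one_step(self, task: Task) -> TaskResult:
        model, masks, config_list = task.load_data()
        # ... prune `model` here, e.g. with a Pruner, and collect the new masks ...
        return TaskResult(task.task_id, model, masks, masks, None)

    def record_task_result(self, task_result: TaskResult):
        self.result = task_result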
...@@ -72,7 +72,7 @@ INTERNAL_SCHEMA = {
}
class OneShotPruner(Pruner):
class BasicPruner(Pruner):
def __init__(self, model: Module, config_list: List[Dict]):
self.data_collector: DataCollector = None
self.metrics_calculator: MetricsCalculator = None
...@@ -120,7 +120,7 @@ class OneShotPruner(Pruner):
return self.bound_model, masks
class LevelPruner(OneShotPruner):
class LevelPruner(BasicPruner):
def __init__(self, model: Module, config_list: List[Dict]):
"""
Parameters
...@@ -154,7 +154,7 @@ class LevelPruner(OneShotPruner):
self.sparsity_allocator = NormalSparsityAllocator(self)
class NormPruner(OneShotPruner):
class NormPruner(BasicPruner):
def __init__(self, model: Module, config_list: List[Dict], p: int,
mode: str = 'normal', dummy_input: Optional[Tensor] = None):
"""
...@@ -275,7 +275,7 @@ class L2NormPruner(NormPruner):
super().__init__(model, config_list, 2, mode, dummy_input)
class FPGMPruner(OneShotPruner):
class FPGMPruner(BasicPruner):
def __init__(self, model: Module, config_list: List[Dict],
mode: str = 'normal', dummy_input: Optional[Tensor] = None):
"""
...@@ -331,7 +331,7 @@ class FPGMPruner(OneShotPruner):
raise NotImplementedError('Only support mode `normal` and `dependency_aware`')
class SlimPruner(OneShotPruner):
class SlimPruner(BasicPruner):
def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor],
training_epochs: int, scale: float = 0.0001, mode='global'):
...@@ -427,7 +427,7 @@ class SlimPruner(OneShotPruner):
raise NotImplementedError('Only support mode `normal` and `global`')
class ActivationPruner(OneShotPruner):
class ActivationPruner(BasicPruner):
def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor], training_batches: int, activation: str = 'relu',
mode: str = 'normal', dummy_input: Optional[Tensor] = None):
...@@ -544,7 +544,7 @@ class ActivationMeanRankPruner(ActivationPruner):
return MeanRankMetricsCalculator(dim=1)
class TaylorFOWeightPruner(OneShotPruner):
class TaylorFOWeightPruner(BasicPruner):
def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor], training_batches: int,
mode: str = 'normal', dummy_input: Optional[Tensor] = None):
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from copy import deepcopy
from typing import Dict, List, Tuple, Callable, Optional
from torch import Tensor
from torch.nn import Module
from nni.algorithms.compression.v2.pytorch.base import Pruner, BasePruningScheduler, Task, TaskResult
from nni.compression.pytorch.speedup import ModelSpeedup
from .tools import TaskGenerator
class PruningScheduler(BasePruningScheduler):
def __init__(self, pruner: Pruner, task_generator: TaskGenerator, finetuner: Callable[[Module], None] = None,
speed_up: bool = False, dummy_input: Tensor = None, evaluator: Optional[Callable[[Module], float]] = None):
"""
Parameters
----------
pruner
The pruner used in the pruning scheduler.
The scheduler will use `Pruner.reset(model, config_list)` to reset it in each iteration.
task_generator
Used to generate the task for each iteration.
finetuner
The finetuner handles all finetuning logic and takes a pytorch module as input.
speed_up
If set to True, speed up the model in each iteration.
dummy_input
If `speed_up` is True, `dummy_input` is required to trace the model during speed up.
evaluator
Evaluate the pruned model and give a score.
If evaluator is None, the best result refers to the latest result.
"""
self.pruner = pruner
self.task_generator = task_generator
self.finetuner = finetuner
self.speed_up = speed_up
self.dummy_input = dummy_input
self.evaluator = evaluator
def generate_task(self) -> Optional[Task]:
return self.task_generator.next()
def record_task_result(self, task_result: TaskResult):
self.task_generator.receive_task_result(task_result)
def pruning_one_step(self, task: Task) -> TaskResult:
model, masks, config_list = task.load_data()
# pruning model
self.pruner.reset(model, config_list)
self.pruner.load_masks(masks)
compact_model, pruner_generated_masks = self.pruner.compress()
compact_model_masks = deepcopy(pruner_generated_masks)
# show the pruning effect
self.pruner.show_pruned_weights()
self.pruner._unwrap_model()
# speed up
if self.speed_up:
ModelSpeedup(compact_model, self.dummy_input, pruner_generated_masks).speedup_model()
compact_model_masks = {}
# finetune
if self.finetuner is not None:
if self.speed_up:
self.finetuner(compact_model)
else:
self.pruner._wrap_model()
self.finetuner(compact_model)
self.pruner._unwrap_model()
# evaluate
score = self.evaluator(compact_model) if self.evaluator is not None else None
# clear model references
self.pruner.clear_model_references()
return TaskResult(task.task_id, compact_model, compact_model_masks, pruner_generated_masks, score)
def get_best_result(self) -> Optional[Tuple[int, Module, Dict[str, Dict[str, Tensor]], float, List[Dict]]]:
return self.task_generator.get_best_result()
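A minimal end-to-end sketch of wiring these pieces together. The toy model, the config list values, and the import paths are assumptions based on this diff's package layout; adjust them to your NNI version:

import torch
from torch import nn

# Import paths assumed from the package layout in this diff.
from nni.algorithms.compression.v2.pytorch.pruning import L2NormPruner, PruningScheduler
from nni.algorithms.compression.v2.pytorch.pruning.tools import AGPTaskGenerator

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Conv2d(16, 32, 3))   # toy model
config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.8}]

pruner = L2NormPruner(model, config_list)
task_generator = AGPTaskGenerator(total_iteration=10, origin_model=model,
                                  origin_config_list=config_list, log_dir='./log')
scheduler = PruningScheduler(pruner, task_generator,
                             finetuner=None,      # or a callable that takes a Module
                             speed_up=False,      # set True and pass dummy_input to shrink the model each iteration
                             evaluator=None)      # or a callable returning a float score
scheduler.compress()
best = scheduler.get_best_result()                # None here, since no evaluator was given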
...@@ -2,7 +2,8 @@ from .base import (
HookCollectorInfo,
DataCollector,
MetricsCalculator,
SparsityAllocator
SparsityAllocator,
TaskGenerator
)
from .data_collector import (
WeightDataCollector,
...@@ -21,3 +22,7 @@ from .sparsity_allocator import (
GlobalSparsityAllocator,
Conv2dDependencyAwareAllocator
)
from .task_generator import (
AGPTaskGenerator,
LinearTaskGenerator
)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from datetime import datetime
import logging
from pathlib import Path
import types
from typing import List, Dict, Optional, Callable, Union
from typing import List, Dict, Tuple, Optional, Callable, Union
import json_tricks
import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from nni.algorithms.compression.v2.pytorch.base import Compressor, LayerInfo
from nni.algorithms.compression.v2.pytorch.base import Compressor, LayerInfo, Task, TaskResult
_logger = logging.getLogger(__name__)
__all__ = ['DataCollector', 'TrainerBasedDataCollector', 'HookCollectorInfo', 'MetricsCalculator', 'SparsityAllocator']
class DataCollector:
"""
...@@ -371,7 +372,7 @@ class SparsityAllocator:
Returns
-------
Dict[str, Tensor]
The key is `weight_mask` or `bias_mask`, value is the final mask.
The key is `weight` or `bias`, value is the final mask.
"""
weight_mask = mask.clone()
...@@ -390,7 +391,7 @@ class SparsityAllocator:
if self.dim is None:
assert weight_mask.size() == weight_size
expand_mask = {'weight_mask': weight_mask}
expand_mask = {'weight': weight_mask}
else:
# expand mask to weight size with dim
assert len(weight_mask.size()) == len(self.dim)
...@@ -400,15 +401,19 @@ class SparsityAllocator:
[idxs.pop(i) for i in reversed(self.dim)]
for i in idxs:
weight_mask = weight_mask.unsqueeze(i)
expand_mask = {'weight_mask': weight_mask.expand(weight_size).clone()}
expand_mask = {'weight': weight_mask.expand(weight_size).clone()}
# NOTE: assume we only mask output, so the mask and bias have a one-to-one correspondence.
# If we support more kinds of masks, this place needs refactoring.
if wrapper.bias_mask is not None and weight_mask.size() == wrapper.bias_mask.size():
expand_mask['bias_mask'] = weight_mask.clone()
expand_mask['bias'] = weight_mask.clone()
return expand_mask
def _compress_mask(self, mask: Tensor) -> Tensor:
"""
This function will reduce the mask with `self.dim` and `self.block_sparse_size`.
e.g., a mask tensor with size [50, 60, 70], self.dim is (0, 1), self.block_sparse_size is [10, 10].
Then, the reduced mask size is [50 / 10, 60 / 10] => [5, 6].
Parameters
----------
name
...@@ -419,7 +424,7 @@ class SparsityAllocator:
Returns
-------
Tensor
Reduce the mask with `self.dim` and `self.block_sparse_size`.
Reduced mask.
"""
if self.dim is None or len(mask.size()) == 1:
mask = mask.clone()
...@@ -440,3 +445,131 @@ class SparsityAllocator:
mask = torch.einsum(ein_expression, mask, torch.ones(self.block_sparse_size).to(mask.device))
return (mask != 0).type_as(mask)
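To visualize the docstring example above, a standalone shape sketch; this is only an illustration of the reduction, not the implementation in this diff:

import torch
import torch.nn.functional as F

full_mask = torch.ones(50, 60, 70)                 # mask over a [50, 60, 70] weight
reduced = full_mask.sum(dim=2)                     # keep the dims in self.dim = (0, 1) -> [50, 60]
reduced = F.avg_pool2d(reduced[None, None], kernel_size=(10, 10)).squeeze()   # 10x10 blocks -> [5, 6]
reduced = (reduced != 0).type_as(reduced)          # binarize, as _compress_mask does
print(reduced.shape)                               # torch.Size([5, 6])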
class TaskGenerator:
"""
This class is used to generate the config list for the pruner in each iteration.
"""
def __init__(self, origin_model: Module, origin_masks: Dict[str, Dict[str, Tensor]] = {},
origin_config_list: List[Dict] = [], log_dir: str = '.', keep_intermidiate_result: bool = False):
"""
Parameters
----------
origin_model
The origin unwrapped pytorch model to be pruned.
origin_masks
The pre-masks on the origin model. These masks may be user-defined or generated by a previous pruning.
origin_config_list
The origin config list provided by the user. Note that this config_list directly configures the origin model,
which means the sparsity provided by the origin_masks should also be recorded in the origin_config_list.
log_dir
The log directory used to save the task generator log.
keep_intermidiate_result
If True, keep the intermediate results, including the intermediate model and masks from each iteration.
"""
assert isinstance(origin_model, Module), 'Only support pytorch module.'
self._log_dir_root = Path(log_dir, datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f')).absolute()
self._log_dir_root.mkdir(parents=True, exist_ok=True)
self._keep_intermidiate_result = keep_intermidiate_result
self._intermidiate_result_dir = Path(self._log_dir_root, 'intermidiate_result')
self._intermidiate_result_dir.mkdir(parents=True, exist_ok=True)
# save origin data in {log_dir}/origin
self._origin_model_path = Path(self._log_dir_root, 'origin', 'model.pth')
self._origin_masks_path = Path(self._log_dir_root, 'origin', 'masks.pth')
self._origin_config_list_path = Path(self._log_dir_root, 'origin', 'config_list.json')
self._save_data('origin', origin_model, origin_masks, origin_config_list)
self._task_id_candidate = 0
self._tasks: Dict[int, Task] = {}
self._pending_tasks: List[Task] = self.init_pending_tasks()
self._best_score = None
self._best_task_id = None
# dump self._tasks into {log_dir}/.tasks
self._dump_tasks_info()
def _dump_tasks_info(self):
tasks = {task_id: task.to_dict() for task_id, task in self._tasks.items()}
with Path(self._log_dir_root, '.tasks').open('w') as f:
json_tricks.dump(tasks, f, indent=4)
def _save_data(self, folder_name: str, model: Module, masks: Dict[str, Dict[str, Tensor]], config_list: List[Dict]):
Path(self._log_dir_root, folder_name).mkdir(parents=True, exist_ok=True)
torch.save(model, Path(self._log_dir_root, folder_name, 'model.pth'))
torch.save(masks, Path(self._log_dir_root, folder_name, 'masks.pth'))
with Path(self._log_dir_root, folder_name, 'config_list.json').open('w') as f:
json_tricks.dump(config_list, f, indent=4)
def update_best_result(self, task_result: TaskResult):
score = task_result.score
if score is not None:
task_id = task_result.task_id
task = self._tasks[task_id]
task.score = score
if self._best_score is None or score > self._best_score:
self._best_score = score
self._best_task_id = task_id
with Path(task.config_list_path).open('r') as fr:
best_config_list = json_tricks.load(fr)
self._save_data('best_result', task_result.compact_model, task_result.compact_model_masks, best_config_list)
def init_pending_tasks(self) -> List[Task]:
raise NotImplementedError()
def generate_tasks(self, task_result: TaskResult) -> List[Task]:
raise NotImplementedError()
def receive_task_result(self, task_result: TaskResult):
"""
Parameters
----------
task_result
The result of the task.
"""
task_id = task_result.task_id
assert task_id in self._tasks, 'Task {} does not exist.'.format(task_id)
self.update_best_result(task_result)
self._tasks[task_id].status = 'Finished'
self._dump_tasks_info()
self._pending_tasks.extend(self.generate_tasks(task_result))
self._dump_tasks_info()
if not self._keep_intermidiate_result:
self._tasks[task_id].clean_up()
def next(self) -> Optional[Task]:
"""
Returns
-------
Optional[Task]
Return the next task from pending tasks.
"""
if len(self._pending_tasks) == 0:
return None
else:
task = self._pending_tasks.pop(0)
task.status = 'Running'
self._dump_tasks_info()
return task
def get_best_result(self) -> Optional[Tuple[int, Module, Dict[str, Dict[str, Tensor]], float, List[Dict]]]:
"""
Returns
-------
Optional[Tuple[int, Module, Dict[str, Dict[str, Tensor]], float, List[Dict]]]
If self._best_task_id is not None,
return the best task id, the best compact model, the masks on the compact model, the score, and the config list used in that task.
"""
if self._best_task_id is not None:
compact_model = torch.load(Path(self._log_dir_root, 'best_result', 'model.pth'))
compact_model_masks = torch.load(Path(self._log_dir_root, 'best_result', 'masks.pth'))
with Path(self._log_dir_root, 'best_result', 'config_list.json').open('r') as f:
config_list = json_tricks.load(f)
return self._best_task_id, compact_model, compact_model_masks, self._best_score, config_list
return None
...@@ -27,8 +27,9 @@ class NormalSparsityAllocator(SparsityAllocator):
metric = metrics[name] * self._compress_mask(wrapper.weight_mask)
prune_num = int(sparsity_rate * metric.numel())
if prune_num == 0:
continue
threshold = torch.topk(metric.view(-1), prune_num, largest=False)[0].max()
threshold = metric.min() - 1
else:
threshold = torch.topk(metric.view(-1), prune_num, largest=False)[0].max()
mask = torch.gt(metric, threshold).type_as(metric)
masks[name] = self._expand_mask(name, mask)
return masks
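A tiny standalone check (toy numbers) of what the new branch does: when `prune_num` is 0, the threshold is pushed below every metric value, so the layer gets an explicit all-ones mask instead of being skipped:

import torch

metric = torch.tensor([0.3, 0.7, 1.2])             # toy per-element importance scores
threshold = metric.min() - 1                       # -0.7, strictly below every score
mask = torch.gt(metric, threshold).type_as(metric)
print(mask)                                        # tensor([1., 1., 1.])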
...@@ -65,19 +66,22 @@ class GlobalSparsityAllocator(SparsityAllocator):
wrapper = self.pruner.get_modules_wrapper()[name]
metric = metric * self._compress_mask(wrapper.weight_mask)
layer_weight_num = wrapper.module.weight.data.numel()
total_weight_num += layer_weight_num
expend_times = int(layer_weight_num / metric.numel())
retention_ratio = 1 - max_sparsity_per_layer.get(name, 1)
retention_numel = math.ceil(retention_ratio * layer_weight_num)
removed_metric_num = math.ceil(retention_numel / (wrapper.weight_mask.numel() / metric.numel()))
stay_metric_num = metric.numel() - removed_metric_num
if stay_metric_num <= 0:
sub_thresholds[name] = metric.min().item() - 1
continue
# Remove the weight parts that must be left
stay_metric = torch.topk(metric.view(-1), stay_metric_num, largest=False)[0]
sub_thresholds[name] = stay_metric.max()
expend_times = int(layer_weight_num / metric.numel())
if expend_times > 1:
stay_metric = stay_metric.expand(stay_metric_num, int(layer_weight_num / metric.numel())).view(-1)
metric_list.append(stay_metric)
total_weight_num += layer_weight_num
total_prune_num = int(total_sparsity * total_weight_num)
if total_prune_num == 0:
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from copy import deepcopy
import logging
from pathlib import Path
from typing import Dict, List
import json_tricks
from torch import Tensor
import torch
from torch.nn import Module
from nni.algorithms.compression.v2.pytorch.base import Task, TaskResult
from nni.algorithms.compression.v2.pytorch.utils.pruning import config_list_canonical, compute_sparsity
from .base import TaskGenerator
_logger = logging.getLogger(__name__)
class FunctionBasedTaskGenerator(TaskGenerator):
def __init__(self, total_iteration: int, origin_model: Module, origin_config_list: List[Dict],
origin_masks: Dict[str, Dict[str, Tensor]] = {}, log_dir: str = '.', keep_intermidiate_result: bool = False):
self.current_iteration = 0
self.target_sparsity = config_list_canonical(origin_model, origin_config_list)
self.total_iteration = total_iteration
super().__init__(origin_model, origin_config_list=self.target_sparsity, origin_masks=origin_masks,
log_dir=log_dir, keep_intermidiate_result=keep_intermidiate_result)
def init_pending_tasks(self) -> List[Task]:
origin_model = torch.load(self._origin_model_path)
origin_masks = torch.load(self._origin_masks_path)
task_result = TaskResult('origin', origin_model, origin_masks, origin_masks, None)
return self.generate_tasks(task_result)
def generate_tasks(self, task_result: TaskResult) -> List[Task]:
compact_model = task_result.compact_model
compact_model_masks = task_result.compact_model_masks
# save intermidiate result
model_path = Path(self._intermidiate_result_dir, '{}_compact_model.pth'.format(task_result.task_id))
masks_path = Path(self._intermidiate_result_dir, '{}_compact_model_masks.pth'.format(task_result.task_id))
torch.save(compact_model, model_path)
torch.save(compact_model_masks, masks_path)
# get current2origin_sparsity and compact2origin_sparsity
origin_model = torch.load(self._origin_model_path)
current2origin_sparsity, compact2origin_sparsity, _ = compute_sparsity(origin_model, compact_model, compact_model_masks, self.target_sparsity)
_logger.info('\nTask %s total real sparsity compared with original model is:\n%s', str(task_result.task_id), json_tricks.dumps(current2origin_sparsity, indent=4))
if task_result.task_id != 'origin':
self._tasks[task_result.task_id].state['current2origin_sparsity'] = current2origin_sparsity
# if reach the total_iteration, no more task will be generated
if self.current_iteration >= self.total_iteration:
return []
task_id = self._task_id_candidate
new_config_list = self.generate_config_list(self.target_sparsity, self.current_iteration, compact2origin_sparsity)
config_list_path = Path(self._intermidiate_result_dir, '{}_config_list.json'.format(task_id))
with Path(config_list_path).open('w') as f:
json_tricks.dump(new_config_list, f, indent=4)
task = Task(task_id, model_path, masks_path, config_list_path)
self._tasks[task_id] = task
self._task_id_candidate += 1
self.current_iteration += 1
return [task]
def generate_config_list(self, target_sparsity: List[Dict], iteration: int, compact2origin_sparsity: List[Dict]) -> List[Dict]:
raise NotImplementedError()
class AGPTaskGenerator(FunctionBasedTaskGenerator):
def generate_config_list(self, target_sparsity: List[Dict], iteration: int, model_based_sparsity: List[Dict]) -> List[Dict]:
config_list = []
for target, mo in zip(target_sparsity, model_based_sparsity):
ori_sparsity = (1 - (1 - iteration / self.total_iteration) ** 3) * target['total_sparsity']
sparsity = max(0.0, (ori_sparsity - mo['total_sparsity']) / (1 - mo['total_sparsity']))
assert 0 <= sparsity <= 1, 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
config_list.append(deepcopy(target))
config_list[-1]['total_sparsity'] = sparsity
return config_list
class LinearTaskGenerator(FunctionBasedTaskGenerator):
def generate_config_list(self, target_sparsity: List[Dict], iteration: int, model_based_sparsity: List[Dict]) -> List[Dict]:
config_list = []
for target, mo in zip(target_sparsity, model_based_sparsity):
ori_sparsity = iteration / self.total_iteration * target['total_sparsity']
sparsity = max(0.0, (ori_sparsity - mo['total_sparsity']) / (1 - mo['total_sparsity']))
assert 0 <= sparsity <= 1, 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
config_list.append(deepcopy(target))
config_list[-1]['total_sparsity'] = sparsity
return config_list
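For intuition, the cubic AGP schedule above ramps sparsity quickly and flattens near the target, while the linear schedule grows evenly. A standalone numeric sketch assuming a target sparsity of 0.8 over 4 iterations:

total_iteration, target = 4, 0.8
for iteration in range(total_iteration + 1):
    agp = (1 - (1 - iteration / total_iteration) ** 3) * target   # cubic AGP curve
    linear = iteration / total_iteration * target                 # linear curve
    print(iteration, round(agp, 3), round(linear, 3))
# 0: 0.0 / 0.0, 1: 0.463 / 0.2, 2: 0.7 / 0.4, 3: 0.788 / 0.6, 4: 0.8 / 0.8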
...@@ -2,8 +2,10 @@
# Licensed under the MIT license.
from copy import deepcopy
from typing import Dict, List
from typing import Dict, List, Tuple
import torch
from torch import Tensor
from torch.nn import Module
...@@ -82,3 +84,113 @@ def dedupe_config_list(config_list: List[Dict]) -> List[Dict]:
for idx in sorted(exclude_idxes, reverse=True):
config_list.pop(idx)
return config_list
def compute_sparsity_compact2origin(origin_model: Module, compact_model: Module, config_list: List[Dict]) -> List[Dict]:
"""
Compare origin model and compact model, return the sparsity of each group mentioned in config list.
A group means all layers mentioned in one config.
e.g., a linear layer named 'linear1' has weight size [100, 100] in the origin model, but in the compact model
the layer with the same name has weight size [100, 50];
then this function will return [{'op_names': 'linear1', 'total_sparsity': 0.5}].
"""
compact2origin_sparsity = []
for config in config_list:
left_weight_num = 0
total_weight_num = 0
for module_name, module in origin_model.named_modules():
module_type = type(module).__name__
if 'op_types' in config and module_type not in config['op_types']:
continue
if 'op_names' in config and module_name not in config['op_names']:
continue
total_weight_num += module.weight.data.numel()
for module_name, module in compact_model.named_modules():
module_type = type(module).__name__
if 'op_types' in config and module_type not in config['op_types']:
continue
if 'op_names' in config and module_name not in config['op_names']:
continue
left_weight_num += module.weight.data.numel()
compact2origin_sparsity.append(deepcopy(config))
compact2origin_sparsity[-1]['total_sparsity'] = 1 - left_weight_num / total_weight_num
return compact2origin_sparsity
def compute_sparsity_mask2compact(compact_model: Module, compact_model_masks: Dict[str, Dict[str, Tensor]], config_list: List[Dict]):
"""
Apply masks on compact model, return the sparsity of each group mentioned in config list.
A group means all layers mentioned in one config.
This function counts all zero elements of the masks in one group,
then divides by the number of weight elements in the group to compute the sparsity.
"""
mask2compact_sparsity = []
for config in config_list:
left_weight_num = 0
total_weight_num = 0
for module_name, module in compact_model.named_modules():
module_type = type(module).__name__
if 'op_types' in config and module_type not in config['op_types']:
continue
if 'op_names' in config and module_name not in config['op_names']:
continue
module_weight_num = module.weight.data.numel()
total_weight_num += module_weight_num
if module_name in compact_model_masks:
weight_mask = compact_model_masks[module_name]['weight']
left_weight_num += len(torch.nonzero(weight_mask, as_tuple=False))
else:
left_weight_num += module_weight_num
mask2compact_sparsity.append(deepcopy(config))
mask2compact_sparsity[-1]['total_sparsity'] = 1 - left_weight_num / total_weight_num
return mask2compact_sparsity
def compute_sparsity(origin_model: Module, compact_model: Module, compact_model_masks: Dict[str, Dict[str, Tensor]],
config_list: List[Dict]) -> Tuple[List[Dict], List[Dict], List[Dict]]:
"""
This function computes how much the origin model has been compressed in the current state.
The current state means `compact_model` + `compact_model_masks`
(i.e., `compact_model_masks` applied on `compact_model`).
The compact model is the origin model after pruning,
and it may have a different structure from origin_model because of speed up.
Returns
-------
Tuple[List[Dict], List[Dict], List[Dict]]
(current2origin_sparsity, compact2origin_sparsity, mask2compact_sparsity).
current2origin_sparsity is how much the origin model has been compressed in the current state.
compact2origin_sparsity is the sparsity obtained by comparing the structure of origin model and compact model.
mask2compact_sparsity is the sparsity computed by counting the zero values in the masks.
"""
compact2origin_sparsity = compute_sparsity_compact2origin(origin_model, compact_model, config_list)
mask2compact_sparsity = compute_sparsity_mask2compact(compact_model, compact_model_masks, config_list)
assert len(compact2origin_sparsity) == len(mask2compact_sparsity), 'Length mismatch.'
current2origin_sparsity = []
for c2o_sparsity, m2c_sparsity, config in zip(compact2origin_sparsity, mask2compact_sparsity, config_list):
current2origin_sparsity.append(deepcopy(config))
current2origin_sparsity[-1]['total_sparsity'] = 1 - (1 - c2o_sparsity['total_sparsity']) * (1 - m2c_sparsity['total_sparsity'])
return current2origin_sparsity, compact2origin_sparsity, mask2compact_sparsity
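A quick worked example of the combination formula above (numbers assumed): if speed-up already removed half of a group's weights and the masks zero out another 20% of what remains, then:

compact2origin = 0.5                                   # structural sparsity from speed-up
mask2compact = 0.2                                     # mask sparsity on the compact model
current2origin = 1 - (1 - compact2origin) * (1 - mask2compact)
print(current2origin)                                  # 0.6 -> 60% of the original weights are effectively pruned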
def get_model_weights_numel(model: Module, config_list: List[Dict], masks: Dict[str, Dict[str, Tensor]] = {}) -> Tuple[Dict, Dict]:
"""
Count the number of weight elements of the layers mentioned in config_list.
If masks is not empty, the masked weights will not be counted.
"""
model_weights_numel = {}
masked_rate = {}
for config in config_list:
for module_name, module in model.named_modules():
module_type = type(module).__name__
if 'op_types' in config and module_type not in config['op_types']:
continue
if 'op_names' in config and module_name not in config['op_names']:
continue
if module_name in masks and isinstance(masks[module_name]['weight'], Tensor):
weight_mask = masks[module_name]['weight']
masked_rate[module_name] = 1 - (weight_mask.sum().item() / weight_mask.numel())
model_weights_numel[module_name] = round(weight_mask.sum().item())
else:
model_weights_numel[module_name] = module.weight.data.numel()
return model_weights_numel, masked_rate