Unverified commit a16e570d, authored by J-shang, committed by GitHub

[Model Compression] Add more Task Generator (#4178)

parent 7a50c96d
@@ -15,7 +15,8 @@ from .tools import TaskGenerator
 class PruningScheduler(BasePruningScheduler):
     def __init__(self, pruner: Pruner, task_generator: TaskGenerator, finetuner: Callable[[Module], None] = None,
-                 speed_up: bool = False, dummy_input: Tensor = None, evaluator: Optional[Callable[[Module], float]] = None):
+                 speed_up: bool = False, dummy_input: Tensor = None, evaluator: Optional[Callable[[Module], float]] = None,
+                 reset_weight: bool = False):
         """
         Parameters
         ----------
@@ -33,6 +34,8 @@ class PruningScheduler(BasePruningScheduler):
         evaluator
             Evaluate the pruned model and give a score.
             If evaluator is None, the best result refers to the latest result.
+        reset_weight
+            If True, the model weights will be reset to the original model weights at the end of each iteration step.
         """
         self.pruner = pruner
         self.task_generator = task_generator
@@ -40,6 +43,7 @@ class PruningScheduler(BasePruningScheduler):
         self.speed_up = speed_up
         self.dummy_input = dummy_input
         self.evaluator = evaluator
+        self.reset_weight = reset_weight

     def generate_task(self) -> Optional[Task]:
         return self.task_generator.next()
@@ -47,12 +51,15 @@
     def record_task_result(self, task_result: TaskResult):
         self.task_generator.receive_task_result(task_result)

-    def pruning_one_step(self, task: Task) -> TaskResult:
+    def pruning_one_step_normal(self, task: Task) -> TaskResult:
+        """
+        generate masks -> speed up -> finetune -> evaluate
+        """
         model, masks, config_list = task.load_data()
-        # pruning model
         self.pruner.reset(model, config_list)
         self.pruner.load_masks(masks)
+        # pruning model
         compact_model, pruner_generated_masks = self.pruner.compress()
         compact_model_masks = deepcopy(pruner_generated_masks)
@@ -75,12 +82,71 @@ class PruningScheduler(BasePruningScheduler):
         self.pruner._unwrap_model()

         # evaluate
-        score = self.evaluator(compact_model) if self.evaluator is not None else None
+        if self.evaluator is not None:
+            if self.speed_up:
+                score = self.evaluator(compact_model)
+            else:
+                self.pruner._wrap_model()
+                score = self.evaluator(compact_model)
+                self.pruner._unwrap_model()
+        else:
+            score = None

         # clear model references
         self.pruner.clear_model_references()

         return TaskResult(task.task_id, compact_model, compact_model_masks, pruner_generated_masks, score)
+    def pruning_one_step_reset_weight(self, task: Task) -> TaskResult:
+        """
+        finetune -> generate masks -> reset weight -> speed up -> evaluate
+        """
+        model, masks, config_list = task.load_data()
+        checkpoint = deepcopy(model.state_dict())
+        self.pruner.reset(model, config_list)
+        self.pruner.load_masks(masks)
+
+        # finetune
+        if self.finetuner is not None:
+            self.finetuner(model)
+
+        # pruning model
+        compact_model, pruner_generated_masks = self.pruner.compress()
+        compact_model_masks = deepcopy(pruner_generated_masks)
+
+        # show the pruning effect
+        self.pruner.show_pruned_weights()
+        self.pruner._unwrap_model()
+
+        # reset model weight
+        compact_model.load_state_dict(checkpoint)
+
+        # speed up
+        if self.speed_up:
+            ModelSpeedup(compact_model, self.dummy_input, pruner_generated_masks).speedup_model()
+            compact_model_masks = {}
+
+        # evaluate
+        if self.evaluator is not None:
+            if self.speed_up:
+                score = self.evaluator(compact_model)
+            else:
+                self.pruner._wrap_model()
+                score = self.evaluator(compact_model)
+                self.pruner._unwrap_model()
+        else:
+            score = None
+
+        # clear model references
+        self.pruner.clear_model_references()
+
+        return TaskResult(task.task_id, compact_model, compact_model_masks, pruner_generated_masks, score)
+
+    def pruning_one_step(self, task: Task) -> TaskResult:
+        if self.reset_weight:
+            return self.pruning_one_step_reset_weight(task)
+        else:
+            return self.pruning_one_step_normal(task)
+
     def get_best_result(self) -> Optional[Tuple[int, Module, Dict[str, Dict[str, Tensor]], float, List[Dict]]]:
         return self.task_generator.get_best_result()
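A minimal usage sketch of the scheduler with the new reset_weight flag follows. The import paths, the L1NormPruner choice, and the compress() entry point are assumptions based on the surrounding v2 API rather than guarantees from this diff; the model, finetuner, and evaluator are toy placeholders.

# Usage sketch; import paths and compress() are assumed from the v2 layout.
import torch
from nni.algorithms.compression.v2.pytorch.pruning import L1NormPruner, PruningScheduler
from nni.algorithms.compression.v2.pytorch.pruning.tools import AGPTaskGenerator

model = torch.nn.Sequential(torch.nn.Conv2d(3, 16, 3), torch.nn.ReLU(),
                            torch.nn.Conv2d(16, 32, 3))
config_list = [{'total_sparsity': 0.8, 'op_types': ['Conv2d']}]

def finetune(model):            # toy placeholder for a real finetuning loop
    pass

def evaluate(model) -> float:   # toy placeholder for a real scoring function
    return torch.rand(1).item()

task_generator = AGPTaskGenerator(total_iteration=10, origin_model=model,
                                  origin_config_list=config_list)
pruner = L1NormPruner(model, config_list)
scheduler = PruningScheduler(pruner, task_generator, finetuner=finetune,
                             speed_up=True, dummy_input=torch.rand(8, 3, 32, 32),
                             evaluator=evaluate,
                             reset_weight=False)  # True -> lottery-ticket-style rewind
scheduler.compress()            # assumed entry point inherited from BasePruningScheduler
best = scheduler.get_best_result()  # Optional[(task_id, model, masks, score, config_list)]
if best is not None:
    task_id, best_model, best_masks, best_score, best_config = best

With reset_weight=True the scheduler dispatches to pruning_one_step_reset_weight, which finetunes before generating masks and then rewinds the compact model to the checkpoint taken at the start of the step.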
@@ -24,5 +24,7 @@ from .sparsity_allocator import (
 )
 from .task_generator import (
     AGPTaskGenerator,
-    LinearTaskGenerator
+    LinearTaskGenerator,
+    LotteryTicketTaskGenerator,
+    SimulatedAnnealingTaskGenerator
 )
@@ -4,15 +4,20 @@
 from copy import deepcopy
 import logging
 from pathlib import Path
-from typing import Dict, List
+from typing import Dict, List, Tuple

 import json_tricks
+import numpy as np
 from torch import Tensor
 import torch
 from torch.nn import Module

 from nni.algorithms.compression.v2.pytorch.base import Task, TaskResult
-from nni.algorithms.compression.v2.pytorch.utils.pruning import config_list_canonical, compute_sparsity
+from nni.algorithms.compression.v2.pytorch.utils.pruning import (
+    config_list_canonical,
+    compute_sparsity,
+    get_model_weights_numel
+)

 from .base import TaskGenerator

 _logger = logging.getLogger(__name__)
@@ -21,6 +26,23 @@ _logger = logging.getLogger(__name__)
 class FunctionBasedTaskGenerator(TaskGenerator):
     def __init__(self, total_iteration: int, origin_model: Module, origin_config_list: List[Dict],
                  origin_masks: Dict[str, Dict[str, Tensor]] = {}, log_dir: str = '.', keep_intermidiate_result: bool = False):
+        """
+        Parameters
+        ----------
+        total_iteration
+            The total number of pruning iterations.
+        origin_model
+            The original, unwrapped pytorch model to be pruned.
+        origin_config_list
+            The original config list provided by the user. Note that this config_list directly configures the origin model,
+            so any sparsity already realized by origin_masks should also be recorded in origin_config_list.
+        origin_masks
+            The pre-applied masks on the origin model. These masks may be user-defined or generated by a previous pruning.
+        log_dir
+            The directory used to save the task generator log.
+        keep_intermidiate_result
+            If True, keep the intermediate results (model and masks) of each iteration.
+        """
         self.current_iteration = 0
         self.target_sparsity = config_list_canonical(origin_model, origin_config_list)
         self.total_iteration = total_iteration
@@ -54,7 +76,7 @@ class FunctionBasedTaskGenerator(TaskGenerator):
         self._tasks[task_result.task_id].state['current2origin_sparsity'] = current2origin_sparsity

         # once total_iteration is reached, no more tasks will be generated
-        if self.current_iteration >= self.total_iteration:
+        if self.current_iteration > self.total_iteration:
             return []

         task_id = self._task_id_candidate
@@ -77,9 +99,9 @@ class FunctionBasedTaskGenerator(TaskGenerator):
 class AGPTaskGenerator(FunctionBasedTaskGenerator):
-    def generate_config_list(self, target_sparsity: List[Dict], iteration: int, model_based_sparsity: List[Dict]) -> List[Dict]:
+    def generate_config_list(self, target_sparsity: List[Dict], iteration: int, compact2origin_sparsity: List[Dict]) -> List[Dict]:
         config_list = []
-        for target, mo in zip(target_sparsity, model_based_sparsity):
+        for target, mo in zip(target_sparsity, compact2origin_sparsity):
             ori_sparsity = (1 - (1 - iteration / self.total_iteration) ** 3) * target['total_sparsity']
             sparsity = max(0.0, (ori_sparsity - mo['total_sparsity']) / (1 - mo['total_sparsity']))
             assert 0 <= sparsity <= 1, 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
@@ -89,12 +111,223 @@ class AGPTaskGenerator(FunctionBasedTaskGenerator):
 class LinearTaskGenerator(FunctionBasedTaskGenerator):
-    def generate_config_list(self, target_sparsity: List[Dict], iteration: int, model_based_sparsity: List[Dict]) -> List[Dict]:
+    def generate_config_list(self, target_sparsity: List[Dict], iteration: int, compact2origin_sparsity: List[Dict]) -> List[Dict]:
         config_list = []
-        for target, mo in zip(target_sparsity, model_based_sparsity):
+        for target, mo in zip(target_sparsity, compact2origin_sparsity):
             ori_sparsity = iteration / self.total_iteration * target['total_sparsity']
             sparsity = max(0.0, (ori_sparsity - mo['total_sparsity']) / (1 - mo['total_sparsity']))
             assert 0 <= sparsity <= 1, 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
             config_list.append(deepcopy(target))
             config_list[-1]['total_sparsity'] = sparsity
         return config_list
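AGPTaskGenerator and LinearTaskGenerator differ only in the interpolation curve: AGP uses the cubic schedule ori_sparsity = (1 - (1 - i/T)^3) * s_target, which prunes aggressively in early iterations, while the linear schedule grows by a constant step. Both then convert this origin-relative sparsity into one relative to the already-compacted model via (ori - done) / (1 - done). A standalone sketch of the two curves with made-up values T = 5 and s_target = 0.8:

# Standalone comparison of the two interpolation schedules above.
T, s_target = 5, 0.8

for i in range(1, T + 1):
    agp = (1 - (1 - i / T) ** 3) * s_target      # cubic: prunes aggressively early
    linear = i / T * s_target                    # constant step per iteration
    print('iter {}: agp={:.3f}, linear={:.3f}'.format(i, agp, linear))

# iter 1: agp=0.390, linear=0.160
# iter 2: agp=0.627, linear=0.320
# ...
# iter 5: agp=0.800, linear=0.800
# Both generators then map this origin-relative sparsity onto the compact model
# with (ori_sparsity - done) / (1 - done), as in the loop bodies above.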
+
+class LotteryTicketTaskGenerator(FunctionBasedTaskGenerator):
+    def __init__(self, total_iteration: int, origin_model: Module, origin_config_list: List[Dict],
+                 origin_masks: Dict[str, Dict[str, Tensor]] = {}, log_dir: str = '.', keep_intermidiate_result: bool = False):
+        super().__init__(total_iteration, origin_model, origin_config_list, origin_masks=origin_masks, log_dir=log_dir,
+                         keep_intermidiate_result=keep_intermidiate_result)
+        self.current_iteration = 1
+
+    def generate_config_list(self, target_sparsity: List[Dict], iteration: int, compact2origin_sparsity: List[Dict]) -> List[Dict]:
+        config_list = []
+        for target, mo in zip(target_sparsity, compact2origin_sparsity):
+            # NOTE: This ori_sparsity formula comes from compression v1 and differs from the one in the paper.
+            # The paper's formula causes numerical problems, so the v1 formula is kept.
+            ori_sparsity = 1 - (1 - target['total_sparsity']) ** (iteration / self.total_iteration)
+            # The following is the formula in the paper:
+            # ori_sparsity = (target['total_sparsity'] * 100) ** (iteration / self.total_iteration) / 100
+            sparsity = max(0.0, (ori_sparsity - mo['total_sparsity']) / (1 - mo['total_sparsity']))
+            assert 0 <= sparsity <= 1, 'sparsity: {}, ori_sparsity: {}, model_sparsity: {}'.format(sparsity, ori_sparsity, mo['total_sparsity'])
+            config_list.append(deepcopy(target))
+            config_list[-1]['total_sparsity'] = sparsity
+        return config_list
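The lottery ticket schedule has a useful invariant: the kept fraction after i of T rounds is (1 - s_target)^(i/T), so every round prunes the same fraction of the weights that survived so far. A standalone check with made-up values:

# Standalone check: each round removes the same fraction of surviving weights.
T, s_target = 5, 0.8

prev_kept = 1.0
for i in range(1, T + 1):
    ori_sparsity = 1 - (1 - s_target) ** (i / T)
    kept = 1 - ori_sparsity                      # equals (1 - s_target) ** (i / T)
    print('iter {}: sparsity={:.3f}, per-round prune rate={:.3f}'.format(
        i, ori_sparsity, 1 - kept / prev_kept))
    prev_kept = kept

# The per-round prune rate is constant: 1 - (1 - 0.8) ** (1 / 5) ~= 0.275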
+
+class SimulatedAnnealingTaskGenerator(TaskGenerator):
+    def __init__(self, origin_model: Module, origin_config_list: List[Dict], origin_masks: Dict[str, Dict[str, Tensor]] = {},
+                 start_temperature: float = 100, stop_temperature: float = 20, cool_down_rate: float = 0.9,
+                 perturbation_magnitude: float = 0.35, log_dir: str = '.', keep_intermidiate_result: bool = False):
+        """
+        Parameters
+        ----------
+        origin_model
+            The original, unwrapped pytorch model to be pruned.
+        origin_config_list
+            The original config list provided by the user. Note that this config_list directly configures the origin model,
+            so any sparsity already realized by origin_masks should also be recorded in origin_config_list.
+        origin_masks
+            The pre-applied masks on the origin model. These masks may be user-defined or generated by a previous pruning.
+        start_temperature
+            Start temperature of the simulated annealing process.
+        stop_temperature
+            Stop temperature of the simulated annealing process.
+        cool_down_rate
+            Cool-down rate of the temperature.
+        perturbation_magnitude
+            Initial perturbation magnitude applied to the sparsities. The magnitude decreases with the current temperature.
+        log_dir
+            The directory used to save the task generator log.
+        keep_intermidiate_result
+            If True, keep the intermediate results (model and masks) of each iteration.
+        """
+        self.start_temperature = start_temperature
+        self.current_temperature = start_temperature
+        self.stop_temperature = stop_temperature
+        self.cool_down_rate = cool_down_rate
+        self.perturbation_magnitude = perturbation_magnitude
+
+        self.weights_numel, self.masked_rate = get_model_weights_numel(origin_model, origin_config_list, origin_masks)
+        self.target_sparsity_list = config_list_canonical(origin_model, origin_config_list)
+        self._adjust_target_sparsity()
+
+        self._temp_config_list = None
+        self._current_sparsity_list = None
+        self._current_score = None
+
+        super().__init__(origin_model, origin_masks=origin_masks, origin_config_list=origin_config_list,
+                         log_dir=log_dir, keep_intermidiate_result=keep_intermidiate_result)
+    def _adjust_target_sparsity(self):
+        """
+        If origin_masks is not empty, re-scale the target sparsity to account for already-masked weights.
+        """
+        if len(self.masked_rate) > 0:
+            for config in self.target_sparsity_list:
+                sparsity, op_names = config['total_sparsity'], config['op_names']
+                remaining_weight_numel = 0
+                pruned_weight_numel = 0
+                for name in op_names:
+                    remaining_weight_numel += self.weights_numel[name]
+                    if name in self.masked_rate:
+                        pruned_weight_numel += 1 / (1 / self.masked_rate[name] - 1) * self.weights_numel[name]
+                config['total_sparsity'] = max(0, sparsity - pruned_weight_numel / (pruned_weight_numel + remaining_weight_numel))
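The arithmetic above assumes weights_numel reports the remaining (unmasked) element count per op, so m / (1 - m) * remaining recovers the already-masked count. A standalone numeric check with made-up sizes:

# Numeric check of _adjust_target_sparsity's arithmetic with made-up numbers.
# A layer originally had 1000 weights; 300 are already masked (masked_rate m = 0.3),
# so 700 remain and weights_numel would report 700.
m, remaining = 0.3, 700

pruned = 1 / (1 / m - 1) * remaining      # m / (1 - m) * remaining
print(pruned)                             # 300.0 -> the masked weight count
print(pruned / (pruned + remaining))      # 0.3   -> sparsity already achieved

# For a user target of 0.8 the adjusted target becomes 0.8 - 0.3 = 0.5,
# i.e. the extra sparsity still to be produced on top of the existing masks.
target = 0.8
print(max(0, target - pruned / (pruned + remaining)))  # 0.5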
+    def _init_temp_config_list(self):
+        self._temp_config_list = []
+        self._temp_sparsity_list = []
+        for config in self.target_sparsity_list:
+            sparsity_config, sparsity = self._init_config_sparsity(config)
+            self._temp_config_list.extend(sparsity_config)
+            self._temp_sparsity_list.append(sparsity)
+
+    def _init_config_sparsity(self, config: Dict) -> Tuple[List[Dict], List]:
+        assert 'total_sparsity' in config, 'Sparsity must be set in config: {}'.format(config)
+
+        target_sparsity = config['total_sparsity']
+        op_names = config['op_names']
+
+        if target_sparsity == 0:
+            return [], []
+
+        while True:
+            random_sparsity = sorted(np.random.uniform(0, 1, len(op_names)))
+            rescaled_sparsity = self._rescale_sparsity(random_sparsity, target_sparsity, op_names)
+            if rescaled_sparsity is not None and rescaled_sparsity[0] >= 0 and rescaled_sparsity[-1] < 1:
+                break
+
+        return self._sparsity_to_config_list(rescaled_sparsity, config), rescaled_sparsity
+    def _rescale_sparsity(self, random_sparsity: List, target_sparsity: float, op_names: List) -> List:
+        assert len(random_sparsity) == len(op_names)
+
+        num_weights = sorted([self.weights_numel[op_name] for op_name in op_names])
+        sparsity = sorted(random_sparsity)
+
+        total_weights = 0
+        total_weights_pruned = 0
+
+        # calculate the scale
+        for idx, num_weight in enumerate(num_weights):
+            total_weights += num_weight
+            total_weights_pruned += int(num_weight * sparsity[idx])
+        if total_weights_pruned == 0:
+            return None
+
+        scale = target_sparsity / (total_weights_pruned / total_weights)
+        # rescale the sparsity
+        sparsity = np.asarray(sparsity) * scale
+        return sparsity
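Concretely, the rescaling multiplies all per-layer sparsities by one factor so that the weighted average over layer sizes hits the requested overall sparsity; because both lists are sorted ascending before pairing, the larger layers receive the larger sparsities. A standalone example with two made-up layers:

# Standalone illustration of the rescaling step with made-up layer sizes.
import numpy as np

num_weights = [1000, 4000]                  # per-layer element counts (sorted)
sparsity = [0.2, 0.6]                       # random per-layer sparsities (sorted)
target = 0.5                                # requested overall sparsity

pruned = sum(int(n * s) for n, s in zip(num_weights, sparsity))   # 200 + 2400 = 2600
scale = target / (pruned / sum(num_weights))                      # 0.5 / 0.52 ~= 0.9615
rescaled = np.asarray(sparsity) * scale
print(rescaled)                             # approximately [0.1923 0.5769]

# The weighted average after rescaling hits the target:
print(sum(n * s for n, s in zip(num_weights, rescaled)) / sum(num_weights))  # 0.5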
+    def _sparsity_to_config_list(self, sparsity: List, config: Dict) -> List[Dict]:
+        sparsity = sorted(sparsity)
+        op_names = [k for k, _ in sorted(self.weights_numel.items(), key=lambda item: item[1]) if k in config['op_names']]
+        assert len(sparsity) == len(op_names)
+        return [{'total_sparsity': sparsity, 'op_names': [op_name]} for sparsity, op_name in zip(sparsity, op_names)]
+    def _update_with_perturbations(self):
+        self._temp_config_list = []
+        self._temp_sparsity_list = []
+        # decrease magnitude with current temperature
+        magnitude = self.current_temperature / self.start_temperature * self.perturbation_magnitude
+        for config, current_sparsity in zip(self.target_sparsity_list, self._current_sparsity_list):
+            if len(current_sparsity) == 0:
+                self._temp_sparsity_list.append([])
+                continue
+            while True:
+                perturbation = np.random.uniform(-magnitude, magnitude, len(current_sparsity))
+                # clip negative sparsities to zero
+                temp_sparsity = np.clip(current_sparsity + perturbation, 0, None)
+                temp_sparsity = self._rescale_sparsity(temp_sparsity, config['total_sparsity'], config['op_names'])
+                if temp_sparsity is not None and temp_sparsity[0] >= 0 and temp_sparsity[-1] < 1:
+                    self._temp_config_list.extend(self._sparsity_to_config_list(temp_sparsity, config))
+                    self._temp_sparsity_list.append(temp_sparsity)
+                    break
+    def _recover_real_sparsity(self, config_list: List[Dict]) -> List[Dict]:
+        """
+        If origin_masks is not empty, the sparsity in the newly generated config_list needs to be rescaled
+        back to a sparsity measured against the origin model.
+        """
+        for config in config_list:
+            assert len(config['op_names']) == 1
+            op_name = config['op_names'][0]
+            if op_name in self.masked_rate:
+                config['total_sparsity'] = self.masked_rate[op_name] + config['total_sparsity'] * (1 - self.masked_rate[op_name])
+        return config_list
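The recovery formula is sparsity composition: if a fraction m of a layer is already masked and a fraction s of the remainder is pruned now, the overall sparsity is m + s * (1 - m). A quick numeric check with made-up values:

# Composition of sparsities as used in _recover_real_sparsity.
m, s = 0.3, 0.5          # already-masked rate, newly requested rate on the remainder
print(m + s * (1 - m))   # 0.65: of 1000 weights, 300 masked + 0.5 * 700 = 650 pruned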
+    def init_pending_tasks(self) -> List[Task]:
+        origin_model = torch.load(self._origin_model_path)
+        origin_masks = torch.load(self._origin_masks_path)
+
+        self.temp_model_path = Path(self._intermidiate_result_dir, 'origin_compact_model.pth')
+        self.temp_masks_path = Path(self._intermidiate_result_dir, 'origin_compact_model_masks.pth')
+        torch.save(origin_model, self.temp_model_path)
+        torch.save(origin_masks, self.temp_masks_path)
+
+        task_result = TaskResult('origin', origin_model, origin_masks, origin_masks, None)
+        return self.generate_tasks(task_result)
+    def generate_tasks(self, task_result: TaskResult) -> List[Task]:
+        # initialize or update the temp config list
+        if self._temp_config_list is None:
+            self._init_temp_config_list()
+        else:
+            score = self._tasks[task_result.task_id].score
+            if self._current_sparsity_list is None:
+                self._current_sparsity_list = deepcopy(self._temp_sparsity_list)
+                self._current_score = score
+            else:
+                delta_E = np.abs(score - self._current_score)
+                probability = np.exp(-1 * delta_E / self.current_temperature)
+                if self._current_score < score or np.random.uniform(0, 1) < probability:
+                    self._current_score = score
+                    self._current_sparsity_list = deepcopy(self._temp_sparsity_list)
+                    self.current_temperature *= self.cool_down_rate
+            if self.current_temperature < self.stop_temperature:
+                return []
+            self._update_with_perturbations()
+
+        task_id = self._task_id_candidate
+
+        new_config_list = self._recover_real_sparsity(deepcopy(self._temp_config_list))
+
+        config_list_path = Path(self._intermidiate_result_dir, '{}_config_list.json'.format(task_id))
+        with Path(config_list_path).open('w') as f:
+            json_tricks.dump(new_config_list, f, indent=4)
+
+        task = Task(task_id, self.temp_model_path, self.temp_masks_path, config_list_path)
+        self._tasks[task_id] = task
+        self._task_id_candidate += 1
+        return [task]
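The accept-or-reject logic in generate_tasks is the Metropolis criterion: a better score is always accepted, a worse one is accepted with probability exp(-|delta| / T), and the temperature cools until it drops below stop_temperature. A compact standalone sketch of the same loop on a toy 1-D objective (the objective and constants are stand-ins, not the real evaluator):

# Standalone sketch of the annealing loop the generator implements, using a
# made-up 1-D objective instead of a pruned-model score.
import numpy as np

rng = np.random.default_rng(0)
score_fn = lambda x: -(x - 0.7) ** 2          # toy objective, maximized at x = 0.7

t, stop_t, cool_down, magnitude = 100.0, 20.0, 0.9, 0.35
current = rng.uniform(0, 1)
current_score = score_fn(current)

while t >= stop_t:
    # perturbation shrinks as the temperature drops, like _update_with_perturbations
    candidate = np.clip(current + rng.uniform(-1, 1) * magnitude * t / 100.0, 0, 1)
    candidate_score = score_fn(candidate)
    delta = abs(candidate_score - current_score)
    # accept improvements always, regressions with probability exp(-delta / t)
    if current_score < candidate_score or rng.uniform(0, 1) < np.exp(-delta / t):
        current, current_score = candidate, candidate_score
        t *= cool_down                        # cool down on acceptance, as above

print(current, current_score)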