Unverified commit 04ae3dee authored by J-shang, committed by GitHub

[Compression] add channel pruning mode for admm pruner & optimize movement pruning performance (#4691)
parent 21e3f27e
@@ -15,6 +15,7 @@ from torchvision import datasets, transforms
 from torch.optim.lr_scheduler import MultiStepLR
 import nni
+from nni.compression.pytorch.speedup import ModelSpeedup
 from nni.compression.pytorch.utils.counter import count_flops_params
 from nni.algorithms.compression.v2.pytorch.pruning.basic_pruner import ADMMPruner
@@ -108,18 +109,17 @@ if __name__ == '__main__':
     config_list = [{
         'sparsity': 0.8,
         'op_types': ['Conv2d'],
-    }, {
-        'sparsity': 0.92,
-        'op_types': ['Conv2d'],
     }]
     # make sure you have used nni.trace to wrap the optimizer class before initializing
     traced_optimizer = nni.trace(torch.optim.SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
-    pruner = ADMMPruner(model, config_list, trainer, traced_optimizer, criterion, iterations=2, training_epochs=2)
+    pruner = ADMMPruner(model, config_list, trainer, traced_optimizer, criterion, iterations=10, training_epochs=1, granularity='coarse-grained')
     _, masks = pruner.compress()
     pruner.show_pruned_weights()
-    # Fine-grained method does not need speedup
+    pruner._unwrap_model()
+    ModelSpeedup(model, torch.randn([128, 3, 32, 32]).to(device), masks).speedup_model()

     print('\n' + '=' * 50 + ' EVALUATE THE MODEL AFTER PRUNING ' + '=' * 50)
     evaluator(model)
...
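The switch to `granularity='coarse-grained'` is what makes the new `ModelSpeedup` call meaningful: a coarse-grained mask zeroes whole output channels, which speedup can physically remove, while a fine-grained mask scatters zeros element-wise and leaves layer shapes unchanged. A minimal sketch of that distinction (the mask shape and values are illustrative, not taken from the example):

```python
import torch

# hypothetical mask for a Conv2d weight of shape [out_channels, in_channels, kH, kW]
weight_mask = torch.ones(8, 3, 3, 3)
weight_mask[:4] = 0  # coarse-grained: the first four output channels are fully masked

# a channel can be physically removed only if every element in it is masked
channel_alive = (weight_mask != 0).flatten(1).any(dim=1)
print(channel_alive)  # tensor([False, False, False, False,  True,  True,  True,  True])
# ModelSpeedup can shrink this layer from 8 to 4 output channels and
# propagate the shape change to the layers that consume its output.
```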
 import functools
+import time
 from tqdm import tqdm
 import torch
@@ -31,7 +32,7 @@ task_to_keys = {
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-gradient_accumulation_steps = 16
+gradient_accumulation_steps = 8

 # a fake criterion because the huggingface output already contains the loss
 def criterion(input, target):
@@ -40,7 +41,7 @@ def criterion(input, target):
 def trainer(model, optimizer, criterion, train_dataloader):
     model.train()
     counter = 0
-    for batch in tqdm(train_dataloader):
+    for batch in train_dataloader:
         counter += 1
         batch.to(device)
         optimizer.zero_grad()
@@ -51,12 +52,14 @@ def trainer(model, optimizer, criterion, train_dataloader):
         loss.backward()
         if counter % gradient_accumulation_steps == 0 or counter == len(train_dataloader):
             optimizer.step()
-        if counter % 16000 == 0:
+        if counter % 800 == 0:
+            print('[{}]: {}'.format(time.asctime(time.localtime(time.time())), counter))
+        if counter % 8000 == 0:
             print('Step {}: {}'.format(counter // gradient_accumulation_steps, evaluator(model, metric, is_regression, validate_dataloader)))

 def evaluator(model, metric, is_regression, eval_dataloader):
     model.eval()
-    for batch in tqdm(eval_dataloader):
+    for batch in eval_dataloader:
         batch.to(device)
         outputs = model(**batch)
         predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()
@@ -70,8 +73,8 @@ if __name__ == '__main__':
     task_name = 'mnli'
     is_regression = False
     num_labels = 1 if is_regression else (3 if task_name == 'mnli' else 2)
-    train_batch_size = 8
-    eval_batch_size = 8
+    train_batch_size = 4
+    eval_batch_size = 4

     set_seed(1024)
@@ -113,7 +116,7 @@ if __name__ == '__main__':
     # make sure you have used nni.trace to wrap the optimizer class before initializing
     traced_optimizer = nni.trace(Adam)(model.parameters(), lr=2e-5)
     pruner = MovementPruner(model, config_list, p_trainer, traced_optimizer, criterion, training_epochs=10,
-                            warm_up_step=3000, cool_down_beginning_step=27000)
+                            warm_up_step=12272, cool_down_beginning_step=110448)
     _, masks = pruner.compress()
     pruner.show_pruned_weights()
...
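The new `warm_up_step=12272` and `cool_down_beginning_step=110448` line up with the optimizer-step count of the updated data pipeline rather than being arbitrary. A back-of-the-envelope check, assuming the MNLI train split size of 392,702 examples:

```python
# effective batch size: per-step batch 4 with 8 gradient-accumulation steps
train_examples = 392_702               # MNLI train split size (assumption)
effective_batch = 4 * 8                # train_batch_size * gradient_accumulation_steps
steps_per_epoch = train_examples / effective_batch
print(round(steps_per_epoch))          # ~12272 optimizer steps per epoch

# warm_up_step = 12272: sparsity stays at 0 for roughly the first epoch;
# cool_down_beginning_step = 110448 = 9 * 12272: the sparsity ramp ends at
# the start of the final epoch of training_epochs=10.
print(9 * 12272)                       # 110448
```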
@@ -158,6 +158,21 @@ class AutoCompressPruner(IterativePruner):
                                                  keep_intermediate_result=keep_intermediate_result)
         if 'traced_optimizer' in admm_params:
             admm_params['traced_optimizer'] = OptimizerConstructHelper.from_trace(model, admm_params['traced_optimizer'])
+        # the granularity of the ADMM stage aligns with the SA stage if 'granularity' is not specified
+        if 'granularity' not in admm_params:
+            # fine-grained admm pruning is used in auto-compress only if the SA stage uses level pruning or fine-grained admm pruning
+            if 'pruning_algorithm' in sa_params:
+                sa_algo = sa_params['pruning_algorithm']
+                sa_algo_params = sa_params.get('pruning_params')
+                if sa_algo in ['level']:
+                    admm_params['granularity'] = 'fine-grained'
+                elif sa_algo in ['admm'] and (sa_algo_params is not None) and not (sa_algo_params.get('granularity') == 'coarse-grained'):
+                    admm_params['granularity'] = 'fine-grained'
+                else:
+                    admm_params['granularity'] = 'coarse-grained'
+            else:
+                admm_params['granularity'] = 'fine-grained'
         pruner = ADMMPruner(None, None, **admm_params)
         super().__init__(pruner, task_generator, finetuner=finetuner, speedup=speedup, dummy_input=dummy_input,
                          evaluator=evaluator, reset_weight=False)
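The alignment rule added above reduces to a small decision table: the ADMM stage stays fine-grained only when the simulated-annealing stage itself prunes element-wise, or when no SA algorithm is given at all. A standalone sketch of the same rule, with illustrative `sa_params` dicts (not NNI API calls):

```python
def resolve_admm_granularity(sa_params: dict) -> str:
    # mirror of the alignment logic in AutoCompressPruner.__init__ (sketch)
    if 'pruning_algorithm' not in sa_params:
        return 'fine-grained'
    sa_algo = sa_params['pruning_algorithm']
    sa_algo_params = sa_params.get('pruning_params')
    if sa_algo == 'level':
        return 'fine-grained'
    if sa_algo == 'admm' and sa_algo_params is not None \
            and sa_algo_params.get('granularity') != 'coarse-grained':
        return 'fine-grained'
    return 'coarse-grained'

assert resolve_admm_granularity({}) == 'fine-grained'
assert resolve_admm_granularity({'pruning_algorithm': 'level'}) == 'fine-grained'
assert resolve_admm_granularity({'pruning_algorithm': 'l1'}) == 'coarse-grained'
assert resolve_admm_granularity({'pruning_algorithm': 'admm',
                                 'pruning_params': {'granularity': 'coarse-grained'}}) == 'coarse-grained'
```

Note that, as written, SA-admm with no explicit `pruning_params` falls through to coarse-grained even though `ADMMPruner` itself defaults to fine-grained; the diff does not say whether that asymmetry is intentional.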
@@ -1073,6 +1073,11 @@ class ADMMPruner(BasicPruner):
         The total iteration number of the admm pruning algorithm.
     training_epochs : int
         The epoch number for training the model in each iteration.
+    granularity : str
+        'fine-grained' or 'coarse-grained'.
+        If 'coarse-grained' is set, the ADMM pruner will generate masks channel-wise on the output channels.
+        The original admm pruning paper implements fine-grained admm pruning,
+        while the auto-compress paper uses coarse-grained admm pruning.

     Examples
     --------
@@ -1091,7 +1096,8 @@ class ADMMPruner(BasicPruner):
     """
     def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
-                 traced_optimizer: Traceable, criterion: Callable[[Tensor, Tensor], Tensor], iterations: int, training_epochs: int):
+                 traced_optimizer: Traceable, criterion: Callable[[Tensor, Tensor], Tensor], iterations: int,
+                 training_epochs: int, granularity: str = 'fine-grained'):
         self.trainer = trainer
         if isinstance(traced_optimizer, OptimizerConstructHelper):
             self.optimizer_helper = traced_optimizer
@@ -1100,6 +1106,8 @@ class ADMMPruner(BasicPruner):
         self.criterion = criterion
         self.iterations = iterations
         self.training_epochs = training_epochs
+        assert granularity in ['fine-grained', 'coarse-grained']
+        self.granularity = granularity
         super().__init__(model, config_list)

     def reset(self, model: Optional[Module], config_list: Optional[List[Dict]]):
@@ -1131,9 +1139,15 @@ class ADMMPruner(BasicPruner):
         else:
             self.data_collector.reset()
         if self.metrics_calculator is None:
-            self.metrics_calculator = NormMetricsCalculator()
+            if self.granularity == 'fine-grained':
+                self.metrics_calculator = NormMetricsCalculator(p=1)
+            elif self.granularity == 'coarse-grained':
+                self.metrics_calculator = NormMetricsCalculator(dim=0, p=1)
         if self.sparsity_allocator is None:
-            self.sparsity_allocator = NormalSparsityAllocator(self)
+            if self.granularity == 'fine-grained':
+                self.sparsity_allocator = NormalSparsityAllocator(self)
+            elif self.granularity == 'coarse-grained':
+                self.sparsity_allocator = NormalSparsityAllocator(self, dim=0)

     def compress(self) -> Tuple[Module, Dict]:
         """
...
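Concretely, `NormMetricsCalculator(dim=0, p=1)` scores each output channel by the L1 norm of its weight slice, and `NormalSparsityAllocator(self, dim=0)` then masks the lowest-scoring channels as a unit. A minimal re-enactment of the two metric shapes, not the NNI internals themselves:

```python
import torch

weight = torch.randn(8, 3, 3, 3)  # Conv2d weight: [out_channels, in_channels, kH, kW]

# fine-grained (p=1): one score per weight element
fine_metric = weight.abs()                          # shape [8, 3, 3, 3]

# coarse-grained (dim=0, p=1): keep dim 0, reduce the rest with the L1 norm
coarse_metric = weight.abs().flatten(1).sum(dim=1)  # shape [8], one score per output channel

# with sparsity=0.8, mask the ~80% of channels with the smallest scores
k = int(0.8 * coarse_metric.numel())                # 6 of 8 channels
pruned = coarse_metric.topk(k, largest=False).indices
mask = torch.ones_like(weight)
mask[pruned] = 0                                    # whole channels zeroed at once
```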
@@ -46,7 +46,8 @@ class PrunerScoredModuleWrapper(PrunerModuleWrapper):
     def forward(self, *inputs):
         # apply mask to weight, bias
-        self.module.weight = torch.mul(self.weight, _StraightThrough.apply(self.weight_score, self.weight_mask))
+        # NOTE: it is unclear why training gets slower and slower if `self.weight_mask` is used without `detach_()`
+        self.module.weight = torch.mul(self.weight, _StraightThrough.apply(self.weight_score, self.weight_mask.detach_()))
         if hasattr(self.module, 'bias') and self.module.bias is not None:
             self.module.bias = torch.mul(self.bias, self.bias_mask)
         return self.module(*inputs)
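For context, `_StraightThrough` applies the straight-through-estimator pattern: the forward pass multiplies the weight by the hard binary mask, while the backward pass routes the gradient to the continuous `weight_score` as if the operation were the identity, keeping the scores learnable despite the non-differentiable masking. A generic sketch of the pattern, not the exact NNI class:

```python
import torch

class StraightThrough(torch.autograd.Function):
    @staticmethod
    def forward(ctx, score, mask):
        # forward: output the hard mask (score only participates in autograd)
        return mask

    @staticmethod
    def backward(ctx, grad_output):
        # backward: pass the gradient through to `score` unchanged;
        # the mask itself receives no gradient
        return grad_output, None
```

Detaching the mask before each forward, as the NOTE above does, keeps stale autograd bookkeeping on the mask from accumulating across steps, which is a plausible (though unconfirmed) explanation for the slowdown the author observed.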
@@ -75,7 +76,7 @@ class WeightScoreTrainerBasedDataCollector(TrainerBasedDataCollector):
         data = {}
         for _, wrapper in self.compressor.get_modules_wrapper().items():
-            data[wrapper.name] = wrapper.weight_score.data.clone().detach()
+            data[wrapper.name] = wrapper.weight_score.data
         return data
...
@@ -19,7 +19,8 @@ class StraightMetricsCalculator(MetricsCalculator):
     def calculate_metrics(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
         metrics = {}
         for name, tensor in data.items():
-            metrics[name] = tensor.clone().detach()
+            # use the in-place `detach_` here to avoid creating a new tensor
+            metrics[name] = tensor.clone().detach_()
         return metrics
...
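The saving chased by the two `detach_` changes is the extra tensor object: `clone().detach()` clones and then builds a second detached tensor on top of the clone, while `clone().detach_()` detaches the freshly cloned tensor in place. A quick illustration:

```python
import torch

t = torch.randn(3, requires_grad=True)

a = t.clone().detach()   # clone, then a second tensor object wrapping the same storage
b = t.clone().detach_()  # the clone itself is detached in place; no extra object

assert not a.requires_grad and not b.requires_grad  # both end up off the autograd graph
```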