Unverified Commit d1450b4d authored by J-shang's avatar J-shang Committed by GitHub
Browse files

fix compression pipeline (#3678)

parent c05a9228
...@@ -230,7 +230,10 @@ def main(args): ...@@ -230,7 +230,10 @@ def main(args):
kw_args['optimizer'] = optimizer kw_args['optimizer'] = optimizer
kw_args['criterion'] = criterion kw_args['criterion'] = criterion
if args.pruner in ('slim', 'mean_activation', 'apoz', 'taylorfo'): if args.pruner in ('mean_activation', 'apoz', 'taylorfo'):
kw_args['sparsity_training_epochs'] = 1
if args.pruner == 'slim':
kw_args['sparsity_training_epochs'] = 5 kw_args['sparsity_training_epochs'] = 5
if args.pruner == 'agp': if args.pruner == 'agp':
...@@ -268,10 +271,11 @@ def main(args): ...@@ -268,10 +271,11 @@ def main(args):
if args.test_only: if args.test_only:
test(args, model, device, criterion, test_loader) test(args, model, device, criterion, test_loader)
# Unwrap all modules to normal state if args.speed_up:
pruner._unwrap_model() # Unwrap all modules to normal state
m_speedup = ModelSpeedup(model, dummy_input, mask_path, device) pruner._unwrap_model()
m_speedup.speedup_model() m_speedup = ModelSpeedup(model, dummy_input, mask_path, device)
m_speedup.speedup_model()
print('start finetuning...') print('start finetuning...')
best_top1 = 0 best_top1 = 0
...@@ -332,6 +336,10 @@ if __name__ == '__main__': ...@@ -332,6 +336,10 @@ if __name__ == '__main__':
'fpgm', 'mean_activation', 'apoz', 'taylorfo'], 'fpgm', 'mean_activation', 'apoz', 'taylorfo'],
help='pruner to use') help='pruner to use')
# speed-up
parser.add_argument('--speed-up', action='store_true', default=False,
help='Whether to speed-up the pruned model')
# fine-tuning # fine-tuning
parser.add_argument('--fine-tune-epochs', type=int, default=160, parser.add_argument('--fine-tune-epochs', type=int, default=160,
help='epochs to fine tune') help='epochs to fine tune')
......
...@@ -526,6 +526,7 @@ class ActivationAPoZRankFilterPruner(IterativePruner): ...@@ -526,6 +526,7 @@ class ActivationAPoZRankFilterPruner(IterativePruner):
super().__init__(model, config_list, pruning_algorithm='apoz', optimizer=optimizer, trainer=trainer, super().__init__(model, config_list, pruning_algorithm='apoz', optimizer=optimizer, trainer=trainer,
criterion=criterion, dependency_aware=dependency_aware, dummy_input=dummy_input, criterion=criterion, dependency_aware=dependency_aware, dummy_input=dummy_input,
activation=activation, num_iterations=1, epochs_per_iteration=sparsity_training_epochs) activation=activation, num_iterations=1, epochs_per_iteration=sparsity_training_epochs)
self.patch_optimizer(self.update_mask)
def _supported_dependency_aware(self): def _supported_dependency_aware(self):
return True return True
...@@ -571,6 +572,7 @@ class ActivationMeanRankFilterPruner(IterativePruner): ...@@ -571,6 +572,7 @@ class ActivationMeanRankFilterPruner(IterativePruner):
super().__init__(model, config_list, pruning_algorithm='mean_activation', optimizer=optimizer, trainer=trainer, super().__init__(model, config_list, pruning_algorithm='mean_activation', optimizer=optimizer, trainer=trainer,
criterion=criterion, dependency_aware=dependency_aware, dummy_input=dummy_input, criterion=criterion, dependency_aware=dependency_aware, dummy_input=dummy_input,
activation=activation, num_iterations=1, epochs_per_iteration=sparsity_training_epochs) activation=activation, num_iterations=1, epochs_per_iteration=sparsity_training_epochs)
self.patch_optimizer(self.update_mask)
def _supported_dependency_aware(self): def _supported_dependency_aware(self):
return True return True
...@@ -26,7 +26,7 @@ echo 'testing level pruner pruning' ...@@ -26,7 +26,7 @@ echo 'testing level pruner pruning'
python3 basic_pruners_torch.py --pruner level --pretrain-epochs 1 --fine-tune-epochs 1 --model lenet --dataset mnist python3 basic_pruners_torch.py --pruner level --pretrain-epochs 1 --fine-tune-epochs 1 --model lenet --dataset mnist
echo 'testing agp pruning' echo 'testing agp pruning'
python3 basic_pruners_torch.py --pruner agp --pretrain-epochs 1 --fine-tune-epochs 1 --model lenet --dataset mnist python3 basic_pruners_torch.py --pruner agp --pretrain-epochs 1 --fine-tune-epochs 1 --model vgg16 --dataset cifar10
echo 'testing mean_activation pruning' echo 'testing mean_activation pruning'
python3 basic_pruners_torch.py --pruner mean_activation --pretrain-epochs 1 --fine-tune-epochs 1 --model vgg16 --dataset cifar10 python3 basic_pruners_torch.py --pruner mean_activation --pretrain-epochs 1 --fine-tune-epochs 1 --model vgg16 --dataset cifar10
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment