"ts/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "e99c579b19c8685cad01f1d2136f3c21cadf0be6"
Unverified commit d1450b4d, authored by J-shang and committed by GitHub

fix compression pipeline (#3678)
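This commit tightens the pruning example pipeline: slim gets its own, longer sparsity-training phase (5 epochs) instead of sharing the 1-epoch setting of the activation/Taylor pruners; model speed-up becomes opt-in behind a new --speed-up flag instead of always running; the APoZ and mean-activation pruners re-register their mask update on the optimizer step via patch_optimizer; and the AGP smoke test switches from LeNet/MNIST to VGG16/CIFAR-10.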

parent c05a9228
@@ -230,7 +230,10 @@ def main(args):
     kw_args['optimizer'] = optimizer
     kw_args['criterion'] = criterion
-    if args.pruner in ('slim', 'mean_activation', 'apoz', 'taylorfo'):
+    if args.pruner in ('mean_activation', 'apoz', 'taylorfo'):
         kw_args['sparsity_training_epochs'] = 1
+    if args.pruner == 'slim':
+        kw_args['sparsity_training_epochs'] = 5
     if args.pruner == 'agp':
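The split above gives slim a longer sparsity-training phase than the activation/Taylor pruners. A minimal sketch of how these kw_args reach a pruner, assuming the NNI v2.x pruner API this script targets (the toy model and trainer below are illustrative, not from the example):

import torch
import torch.nn as nn
import torch.nn.functional as F
from nni.algorithms.compression.pytorch.pruning import SlimPruner

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

def trainer(model, optimizer, criterion, epoch):
    # a few synthetic batches; spatial channel means stand in for logits
    model.train()
    for _ in range(4):
        data = torch.randn(2, 3, 8, 8)
        target = torch.randint(0, 8, (2,))
        loss = criterion(model(data).mean(dim=(2, 3)), target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

pruner = SlimPruner(model, [{'sparsity': 0.5, 'op_types': ['BatchNorm2d']}],
                    optimizer=optimizer, trainer=trainer, criterion=F.cross_entropy,
                    sparsity_training_epochs=5)  # the slim-specific value set above
pruner.compress()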
@@ -268,10 +271,11 @@ def main(args):
     if args.test_only:
         test(args, model, device, criterion, test_loader)
-    # Unwrap all modules to normal state
-    pruner._unwrap_model()
-    m_speedup = ModelSpeedup(model, dummy_input, mask_path, device)
-    m_speedup.speedup_model()
+    if args.speed_up:
+        # Unwrap all modules to normal state
+        pruner._unwrap_model()
+        m_speedup = ModelSpeedup(model, dummy_input, mask_path, device)
+        m_speedup.speedup_model()
     print('start finetuning...')
     best_top1 = 0
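The path now guarded by --speed-up is the usual NNI mask-export then speed-up flow. A minimal end-to-end sketch under the v2.x API (the toy model, file names, and sparsity are illustrative):

import torch
import torch.nn as nn
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner
from nni.compression.pytorch import ModelSpeedup

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Conv2d(16, 32, 3))
pruner = L1FilterPruner(model, [{'sparsity': 0.5, 'op_types': ['Conv2d']}])
pruner.compress()
pruner.export_model(model_path='pruned.pth', mask_path='mask.pth')

pruner._unwrap_model()                  # return modules to normal state first
dummy_input = torch.randn(1, 3, 16, 16)
ModelSpeedup(model, dummy_input, 'mask.pth', torch.device('cpu')).speedup_model()
print(model)                            # pruned conv layers are now physically smaller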
@@ -332,6 +336,10 @@ if __name__ == '__main__':
                                  'fpgm', 'mean_activation', 'apoz', 'taylorfo'],
                         help='pruner to use')
+    # speed-up
+    parser.add_argument('--speed-up', action='store_true', default=False,
+                        help='Whether to speed-up the pruned model')
     # fine-tuning
     parser.add_argument('--fine-tune-epochs', type=int, default=160,
                         help='epochs to fine tune')
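Since the flag defaults to False, existing invocations are unchanged; an opt-in run might look like the line below (argument combination illustrative, patterned on the test script further down):

python3 basic_pruners_torch.py --pruner agp --pretrain-epochs 1 --fine-tune-epochs 1 --model vgg16 --dataset cifar10 --speed-up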
......
@@ -526,6 +526,7 @@ class ActivationAPoZRankFilterPruner(IterativePruner):
         super().__init__(model, config_list, pruning_algorithm='apoz', optimizer=optimizer, trainer=trainer,
                          criterion=criterion, dependency_aware=dependency_aware, dummy_input=dummy_input,
                          activation=activation, num_iterations=1, epochs_per_iteration=sparsity_training_epochs)
+        self.patch_optimizer(self.update_mask)

     def _supported_dependency_aware(self):
         return True
@@ -571,6 +572,7 @@ class ActivationMeanRankFilterPruner(IterativePruner):
         super().__init__(model, config_list, pruning_algorithm='mean_activation', optimizer=optimizer, trainer=trainer,
                          criterion=criterion, dependency_aware=dependency_aware, dummy_input=dummy_input,
                          activation=activation, num_iterations=1, epochs_per_iteration=sparsity_training_epochs)
+        self.patch_optimizer(self.update_mask)

     def _supported_dependency_aware(self):
         return True
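The restored patch_optimizer(self.update_mask) call in both pruners is what keeps masks in sync during sparsity training: NNI patches the bound optimizer so the mask-update task fires after every step. A minimal sketch of that wrapping pattern in plain PyTorch (patch_step and the callback are illustrative stand-ins, not NNI's internals):

import types
import torch

def patch_step(optimizer, *tasks):
    # wrap optimizer.step() so each task runs after every parameter update
    raw_step = optimizer.step
    def step(self, *args, **kwargs):
        output = raw_step(*args, **kwargs)
        for task in tasks:  # e.g. the pruner's update_mask
            task()
        return output
    optimizer.step = types.MethodType(step, optimizer)

weight = torch.nn.Parameter(torch.randn(2))
opt = torch.optim.SGD([weight], lr=0.1)
patch_step(opt, lambda: print('update_mask would run here'))
weight.grad = torch.ones(2)
opt.step()  # updates weight, then fires the callback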
@@ -26,7 +26,7 @@ echo 'testing level pruner pruning'
 python3 basic_pruners_torch.py --pruner level --pretrain-epochs 1 --fine-tune-epochs 1 --model lenet --dataset mnist
 echo 'testing agp pruning'
-python3 basic_pruners_torch.py --pruner agp --pretrain-epochs 1 --fine-tune-epochs 1 --model lenet --dataset mnist
+python3 basic_pruners_torch.py --pruner agp --pretrain-epochs 1 --fine-tune-epochs 1 --model vgg16 --dataset cifar10
 echo 'testing mean_activation pruning'
 python3 basic_pruners_torch.py --pruner mean_activation --pretrain-epochs 1 --fine-tune-epochs 1 --model vgg16 --dataset cifar10
......