Unverified Commit 9e5e0e3c authored by J-shang, committed by GitHub

fix pruner doc typo and example bug (#3223)

parent 0c13ea49
@@ -582,7 +582,7 @@ PyTorch code
 .. code-block:: python
-    from nni.algorithms.compression.pytorch.pruning import ADMMPruner
+    from nni.algorithms.compression.pytorch.pruning import AutoCompressPruner
     config_list = [{
         'sparsity': 0.5,
         'op_types': ['Conv2d']
@@ -633,7 +633,7 @@ PyTorch code
 You can view :githublink:`example <examples/model_compress/amc/>` for more information.
-User configuration for AutoCompress Pruner
+User configuration for AMC Pruner
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 **PyTorch**
@@ -229,7 +229,7 @@ def main(args):
     # used to save the performance of the original & pruned & finetuned models
     result = {'flops': {}, 'params': {}, 'performance':{}}
-    flops, params = count_flops_params(model, get_input_size(args.dataset))
+    flops, params, _ = count_flops_params(model, get_input_size(args.dataset))
     result['flops']['original'] = flops
     result['params']['original'] = params
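
The fix above reflects that count_flops_params returns a third value (per-layer counting details) in addition to the totals, so the old two-value unpacking failed. A minimal, self-contained sketch of the updated call; the import path and the torchvision model are illustrative assumptions, not taken from this commit:

    # count_flops_params now yields (flops, params, per-layer results);
    # the third value is discarded here with '_'.
    import torchvision.models as models
    from nni.compression.pytorch.utils.counter import count_flops_params

    model = models.resnet18()
    flops, params, _ = count_flops_params(model, (1, 3, 224, 224))
    print(f'FLOPs: {flops}, Params: {params}')
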
@@ -238,7 +238,7 @@ def main(args):
     result['performance']['original'] = evaluation_result
     # module types to prune, only "Conv2d" supported for channel pruning
-    if args.base_algo in ['l1', 'l2']:
+    if args.base_algo in ['l1', 'l2', 'fpgm']:
         op_types = ['Conv2d']
     elif args.base_algo == 'level':
         op_types = ['default']
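
For context, this branch selects which layer types the chosen base algorithm may prune: FPGM, like the L1/L2 filter pruners, ranks whole Conv2d filters, while 'level' is fine-grained and uses the 'default' op type. A small sketch of how the resulting config comes together; the 0.5 sparsity is an illustrative value:

    # Map the base algorithm to the op types it can prune.
    base_algo = 'fpgm'            # one of 'level', 'l1', 'l2', 'fpgm'

    if base_algo in ['l1', 'l2', 'fpgm']:
        op_types = ['Conv2d']     # filter/channel pruning targets Conv2d layers
    elif base_algo == 'level':
        op_types = ['default']    # fine-grained pruning uses the default op set

    config_list = [{
        'sparsity': 0.5,          # illustrative overall sparsity target
        'op_types': op_types,
    }]
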
@@ -261,7 +261,7 @@ def main(args):
     elif args.pruner == 'ADMMPruner':
         # users are free to change the config here
         if args.model == 'LeNet':
-            if args.base_algo in ['l1', 'l2']:
+            if args.base_algo in ['l1', 'l2', 'fpgm']:
                 config_list = [{
                     'sparsity': 0.8,
                     'op_types': ['Conv2d'],
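
The LeNet branch above builds a per-layer ADMM config whenever the base algorithm is filter-based, which now includes 'fpgm'. A sketch of what such a config_list can look like; the 'op_names' entries and the second sparsity value are illustrative assumptions, not taken from this commit:

    # Per-layer sparsity for a filter-based base algorithm ('l1'/'l2'/'fpgm').
    config_list = [{
        'sparsity': 0.8,
        'op_types': ['Conv2d'],
        'op_names': ['conv1'],    # hypothetical layer name
    }, {
        'sparsity': 0.92,         # illustrative second-layer sparsity
        'op_types': ['Conv2d'],
        'op_names': ['conv2'],    # hypothetical layer name
    }]
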
@@ -337,7 +337,7 @@ def main(args):
         torch.save(model.state_dict(), os.path.join(args.experiment_data_dir, 'model_speed_up.pth'))
         print('Speed up model saved to %s', args.experiment_data_dir)
-    flops, params = count_flops_params(model, get_input_size(args.dataset))
+    flops, params, _ = count_flops_params(model, get_input_size(args.dataset))
     result['flops']['speedup'] = flops
     result['params']['speedup'] = params
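
This recount runs after the pruned model has been sped up, i.e. after the masked filters are physically removed, which is why the FLOPs and parameter counts shrink here. A sketch of that flow, assuming NNI's ModelSpeedup API; 'model', the mask file path, and the input shape are placeholders:

    # Apply the exported masks to physically shrink the network, then recount.
    # 'model' and 'mask.pth' stand in for the pruned network and the mask file
    # exported earlier (e.g. via pruner.export_model).
    import torch
    from nni.compression.pytorch import ModelSpeedup
    from nni.compression.pytorch.utils.counter import count_flops_params

    dummy_input = torch.randn(1, 3, 32, 32)
    ModelSpeedup(model, dummy_input, 'mask.pth').speedup_model()
    flops, params, _ = count_flops_params(model, (1, 3, 32, 32))
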
@@ -414,7 +414,7 @@ if __name__ == '__main__':
     parser.add_argument('--pruner', type=str, default='SimulatedAnnealingPruner',
                         help='pruner to use')
     parser.add_argument('--base-algo', type=str, default='l1',
-                        help='base pruning algorithm. level, l1 or l2')
+                        help='base pruning algorithm. level, l1, l2, or fpgm')
     parser.add_argument('--sparsity', type=float, default=0.1,
                         help='target overall target sparsity')
     # param for SimulatedAnnealingPruner