"...resnet50_tensorflow.git" did not exist on "ec54319106c913c503517ac78882061e8b8a9607"
Unverified Commit 9e5e0e3c authored by J-shang's avatar J-shang Committed by GitHub
Browse files

fix pruner doc typo and example bug (#3223)

parent 0c13ea49
...@@ -582,7 +582,7 @@ PyTorch code ...@@ -582,7 +582,7 @@ PyTorch code
.. code-block:: python .. code-block:: python
from nni.algorithms.compression.pytorch.pruning import ADMMPruner from nni.algorithms.compression.pytorch.pruning import AutoCompressPruner
config_list = [{ config_list = [{
'sparsity': 0.5, 'sparsity': 0.5,
'op_types': ['Conv2d'] 'op_types': ['Conv2d']
...@@ -633,7 +633,7 @@ PyTorch code ...@@ -633,7 +633,7 @@ PyTorch code
You can view :githublink:`example <examples/model_compress/amc/>` for more information. You can view :githublink:`example <examples/model_compress/amc/>` for more information.
User configuration for AutoCompress Pruner User configuration for AMC Pruner
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
**PyTorch** **PyTorch**
......
...@@ -229,7 +229,7 @@ def main(args): ...@@ -229,7 +229,7 @@ def main(args):
# used to save the performance of the original & pruned & finetuned models # used to save the performance of the original & pruned & finetuned models
result = {'flops': {}, 'params': {}, 'performance':{}} result = {'flops': {}, 'params': {}, 'performance':{}}
flops, params = count_flops_params(model, get_input_size(args.dataset)) flops, params, _ = count_flops_params(model, get_input_size(args.dataset))
result['flops']['original'] = flops result['flops']['original'] = flops
result['params']['original'] = params result['params']['original'] = params
...@@ -238,7 +238,7 @@ def main(args): ...@@ -238,7 +238,7 @@ def main(args):
result['performance']['original'] = evaluation_result result['performance']['original'] = evaluation_result
# module types to prune, only "Conv2d" supported for channel pruning # module types to prune, only "Conv2d" supported for channel pruning
if args.base_algo in ['l1', 'l2']: if args.base_algo in ['l1', 'l2', 'fpgm']:
op_types = ['Conv2d'] op_types = ['Conv2d']
elif args.base_algo == 'level': elif args.base_algo == 'level':
op_types = ['default'] op_types = ['default']
...@@ -261,7 +261,7 @@ def main(args): ...@@ -261,7 +261,7 @@ def main(args):
elif args.pruner == 'ADMMPruner': elif args.pruner == 'ADMMPruner':
# users are free to change the config here # users are free to change the config here
if args.model == 'LeNet': if args.model == 'LeNet':
if args.base_algo in ['l1', 'l2']: if args.base_algo in ['l1', 'l2', 'fpgm']:
config_list = [{ config_list = [{
'sparsity': 0.8, 'sparsity': 0.8,
'op_types': ['Conv2d'], 'op_types': ['Conv2d'],
...@@ -337,7 +337,7 @@ def main(args): ...@@ -337,7 +337,7 @@ def main(args):
torch.save(model.state_dict(), os.path.join(args.experiment_data_dir, 'model_speed_up.pth')) torch.save(model.state_dict(), os.path.join(args.experiment_data_dir, 'model_speed_up.pth'))
print('Speed up model saved to %s', args.experiment_data_dir) print('Speed up model saved to %s', args.experiment_data_dir)
flops, params = count_flops_params(model, get_input_size(args.dataset)) flops, params, _ = count_flops_params(model, get_input_size(args.dataset))
result['flops']['speedup'] = flops result['flops']['speedup'] = flops
result['params']['speedup'] = params result['params']['speedup'] = params
...@@ -414,7 +414,7 @@ if __name__ == '__main__': ...@@ -414,7 +414,7 @@ if __name__ == '__main__':
parser.add_argument('--pruner', type=str, default='SimulatedAnnealingPruner', parser.add_argument('--pruner', type=str, default='SimulatedAnnealingPruner',
help='pruner to use') help='pruner to use')
parser.add_argument('--base-algo', type=str, default='l1', parser.add_argument('--base-algo', type=str, default='l1',
help='base pruning algorithm. level, l1 or l2') help='base pruning algorithm. level, l1, l2, or fpgm')
parser.add_argument('--sparsity', type=float, default=0.1, parser.add_argument('--sparsity', type=float, default=0.1,
help='target overall target sparsity') help='target overall target sparsity')
# param for SimulatedAnnealingPruner # param for SimulatedAnnealingPruner
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment