Unverified Commit c5255669 authored by colorjam's avatar colorjam Committed by GitHub
Browse files

FLOPs/Params counter refinement (#2632)

parent 00ddf3aa
...@@ -30,7 +30,7 @@ jobs: ...@@ -30,7 +30,7 @@ jobs:
python3 -m pip install torch==1.5.0+cpu torchvision==0.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html --user python3 -m pip install torch==1.5.0+cpu torchvision==0.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html --user
python3 -m pip install tensorflow==2.2.0 --user python3 -m pip install tensorflow==2.2.0 --user
python3 -m pip install keras==2.4.2 --user python3 -m pip install keras==2.4.2 --user
python3 -m pip install gym onnx peewee --user python3 -m pip install gym onnx peewee thop --user
python3 -m pip install sphinx==1.8.3 sphinx-argparse==0.2.5 sphinx-markdown-tables==0.0.9 sphinx-rtd-theme==0.4.2 sphinxcontrib-websupport==1.1.0 recommonmark==0.5.0 nbsphinx --user python3 -m pip install sphinx==1.8.3 sphinx-argparse==0.2.5 sphinx-markdown-tables==0.0.9 sphinx-rtd-theme==0.4.2 sphinxcontrib-websupport==1.1.0 recommonmark==0.5.0 nbsphinx --user
sudo apt-get install swig -y sudo apt-get install swig -y
nnictl package install --name=SMAC nnictl package install --name=SMAC
...@@ -59,6 +59,7 @@ jobs: ...@@ -59,6 +59,7 @@ jobs:
python3 -m pip install --upgrade pip setuptools --user python3 -m pip install --upgrade pip setuptools --user
python3 -m pip install pylint==2.3.1 astroid==2.2.5 --user python3 -m pip install pylint==2.3.1 astroid==2.2.5 --user
python3 -m pip install coverage --user python3 -m pip install coverage --user
python3 -m pip install thop --user
echo "##vso[task.setvariable variable=PATH]${HOME}/.local/bin:${PATH}" echo "##vso[task.setvariable variable=PATH]${HOME}/.local/bin:${PATH}"
displayName: 'Install python tools' displayName: 'Install python tools'
- script: | - script: |
......
...@@ -129,5 +129,6 @@ from nni.compression.torch.utils.counter import count_flops_params ...@@ -129,5 +129,6 @@ from nni.compression.torch.utils.counter import count_flops_params
# Given input size (1, 1, 28, 28) # Given input size (1, 1, 28, 28)
flops, params = count_flops_params(model, (1, 1, 28, 28)) flops, params = count_flops_params(model, (1, 1, 28, 28))
# Format output size to M (i.e., 10^6)
print(f'FLOPs: {flops/1e6:.3f}M, Params: {params/1e6:.3f}M) print(f'FLOPs: {flops/1e6:.3f}M, Params: {params/1e6:.3f}M)
``` ```
\ No newline at end of file
...@@ -14,4 +14,5 @@ nbsphinx ...@@ -14,4 +14,5 @@ nbsphinx
schema schema
tensorboard tensorboard
scikit-learn==0.20 scikit-learn==0.20
thop
https://download.pytorch.org/whl/cpu/torch-1.3.1%2Bcpu-cp37-cp37m-linux_x86_64.whl https://download.pytorch.org/whl/cpu/torch-1.3.1%2Bcpu-cp37-cp37m-linux_x86_64.whl
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
import logging
import torch import torch
import torch.nn as nn import torch.nn as nn
from nni.compression.torch.compressor import PrunerModuleWrapper from nni.compression.torch.compressor import PrunerModuleWrapper
_logger = logging.getLogger(__name__)
try: try:
from thop import profile from thop import profile
except ImportError: except Exception as e:
_logger.warning('Please install thop using command: pip install thop') print('thop is not found, please install the python package: thop')
raise
def count_flops_params(model: nn.Module, input_size, verbose=True): def count_flops_params(model: nn.Module, input_size, verbose=True):
...@@ -61,8 +58,16 @@ def count_flops_params(model: nn.Module, input_size, verbose=True): ...@@ -61,8 +58,16 @@ def count_flops_params(model: nn.Module, input_size, verbose=True):
flops, params = profile(model, inputs=(inputs, ), custom_ops=custom_ops, verbose=verbose) flops, params = profile(model, inputs=(inputs, ), custom_ops=custom_ops, verbose=verbose)
for m in hook_module_list: for m in hook_module_list:
m._buffers.pop("weight_mask") m._buffers.pop("weight_mask")
# Remove registerd buffer on the model, and fixed following issue:
# https://github.com/Lyken17/pytorch-OpCounter/issues/96
for m in model.modules():
if 'total_ops' in m._buffers:
m._buffers.pop("total_ops")
if 'total_params' in m._buffers:
m._buffers.pop("total_params")
return flops, params return flops, params
......
...@@ -15,6 +15,7 @@ jobs: ...@@ -15,6 +15,7 @@ jobs:
python3 -m pip install torch==1.3.1 --user python3 -m pip install torch==1.3.1 --user
python3 -m pip install keras==2.1.6 --user python3 -m pip install keras==2.1.6 --user
python3 -m pip install tensorflow-gpu==1.15.2 tensorflow-estimator==1.15.1 --force --user python3 -m pip install tensorflow-gpu==1.15.2 tensorflow-estimator==1.15.1 --force --user
python3 -m pip install thop --user
sudo apt-get install swig -y sudo apt-get install swig -y
PATH=$HOME/.local/bin:$PATH nnictl package install --name=SMAC PATH=$HOME/.local/bin:$PATH nnictl package install --name=SMAC
PATH=$HOME/.local/bin:$PATH nnictl package install --name=BOHB PATH=$HOME/.local/bin:$PATH nnictl package install --name=BOHB
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment