Unverified commit a9668347 authored by J-shang, committed by GitHub

[BugFix] fix compression bugs (#5140)

parent 56c6cfea
@@ -64,7 +64,7 @@ Usage
 import torch
 import torch.nn.functional as F
 from torch.optim import SGD
-from scripts.compression_mnist_model import TorchModel, device, trainer, evaluator, test_trt
+from nni_assets.compression.mnist_model import TorchModel, device, trainer, evaluator, test_trt
 config_list = [{
     'quant_types': ['input', 'weight'],
...
-2404b8d0c3958a0191b77bbe882456e4
\ No newline at end of file
+06c37bd5c886478ae20a1fc552af729a
\ No newline at end of file
@@ -84,7 +84,7 @@ Usage
 import torch
 import torch.nn.functional as F
 from torch.optim import SGD
-from scripts.compression_mnist_model import TorchModel, device, trainer, evaluator, test_trt
+from nni_assets.compression.mnist_model import TorchModel, device, trainer, evaluator, test_trt
 config_list = [{
     'quant_types': ['input', 'weight'],
@@ -174,9 +174,9 @@ finetuning the model by using QAT
 .. code-block:: none
-    Average test loss: 0.5386, Accuracy: 8619/10000 (86%)
-    Average test loss: 0.1553, Accuracy: 9521/10000 (95%)
-    Average test loss: 0.1001, Accuracy: 9686/10000 (97%)
+    Average test loss: 0.6058, Accuracy: 8534/10000 (85%)
+    Average test loss: 0.1585, Accuracy: 9508/10000 (95%)
+    Average test loss: 0.0920, Accuracy: 9717/10000 (97%)
@@ -207,7 +207,7 @@ export model and get calibration_config
 .. code-block:: none
-    calibration_config: {'conv1': {'weight_bits': 8, 'weight_scale': tensor([0.0029], device='cuda:0'), 'weight_zero_point': tensor([98.], device='cuda:0'), 'input_bits': 8, 'tracked_min_input': -0.4242129623889923, 'tracked_max_input': 2.821486711502075}, 'conv2': {'weight_bits': 8, 'weight_scale': tensor([0.0017], device='cuda:0'), 'weight_zero_point': tensor([124.], device='cuda:0'), 'input_bits': 8, 'tracked_min_input': 0.0, 'tracked_max_input': 8.848002433776855}, 'fc1': {'weight_bits': 8, 'weight_scale': tensor([0.0010], device='cuda:0'), 'weight_zero_point': tensor([134.], device='cuda:0'), 'input_bits': 8, 'tracked_min_input': 0.0, 'tracked_max_input': 14.64758586883545}, 'fc2': {'weight_bits': 8, 'weight_scale': tensor([0.0013], device='cuda:0'), 'weight_zero_point': tensor([121.], device='cuda:0'), 'input_bits': 8, 'tracked_min_input': 0.0, 'tracked_max_input': 15.807988166809082}, 'relu1': {'output_bits': 8, 'tracked_min_output': 0.0, 'tracked_max_output': 9.041301727294922}, 'relu2': {'output_bits': 8, 'tracked_min_output': 0.0, 'tracked_max_output': 15.143928527832031}, 'relu3': {'output_bits': 8, 'tracked_min_output': 0.0, 'tracked_max_output': 16.151935577392578}, 'relu4': {'output_bits': 8, 'tracked_min_output': 0.0, 'tracked_max_output': 11.749024391174316}}
+    calibration_config: {'conv1': {'weight_bits': 8, 'weight_scale': tensor([0.0029], device='cuda:0'), 'weight_zero_point': tensor([97.], device='cuda:0'), 'input_bits': 8, 'tracked_min_input': -0.4242129623889923, 'tracked_max_input': 2.821486711502075}, 'conv2': {'weight_bits': 8, 'weight_scale': tensor([0.0017], device='cuda:0'), 'weight_zero_point': tensor([115.], device='cuda:0'), 'input_bits': 8, 'tracked_min_input': 0.0, 'tracked_max_input': 7.800363063812256}, 'fc1': {'weight_bits': 8, 'weight_scale': tensor([0.0010], device='cuda:0'), 'weight_zero_point': tensor([121.], device='cuda:0'), 'input_bits': 8, 'tracked_min_input': 0.0, 'tracked_max_input': 13.914573669433594}, 'fc2': {'weight_bits': 8, 'weight_scale': tensor([0.0012], device='cuda:0'), 'weight_zero_point': tensor([125.], device='cuda:0'), 'input_bits': 8, 'tracked_min_input': 0.0, 'tracked_max_input': 11.657418251037598}, 'relu1': {'output_bits': 8, 'tracked_min_output': 0.0, 'tracked_max_output': 7.897384166717529}, 'relu2': {'output_bits': 8, 'tracked_min_output': 0.0, 'tracked_max_output': 14.337020874023438}, 'relu3': {'output_bits': 8, 'tracked_min_output': 0.0, 'tracked_max_output': 11.884227752685547}, 'relu4': {'output_bits': 8, 'tracked_min_output': 0.0, 'tracked_max_output': 9.330422401428223}}
@@ -237,8 +237,8 @@ build tensorRT engine to make a real speedup
 .. code-block:: none
-    Loss: 0.10061546401977539 Accuracy: 96.83%
-    Inference elapsed_time (whole dataset): 0.04322671890258789s
+    Loss: 0.09235906448364258 Accuracy: 97.19%
+    Inference elapsed_time (whole dataset): 0.03632998466491699s
@@ -300,7 +300,7 @@ input tensor: ``torch.randn(128, 3, 32, 32)``
 .. rst-class:: sphx-glr-timing
-   **Total running time of the script:** ( 1 minutes 4.509 seconds)
+   **Total running time of the script:** ( 1 minutes 13.658 seconds)
 .. _sphx_glr_download_tutorials_quantization_speedup.py:
...
@@ -5,17 +5,19 @@
 Computation times
 =================
-**00:20.822** total execution time for **tutorials** files:
+**01:39.686** total execution time for **tutorials** files:
 +-----------------------------------------------------------------------------------------------------+-----------+--------+
-| :ref:`sphx_glr_tutorials_pruning_bert_glue.py` (``pruning_bert_glue.py``) | 00:20.822 | 0.0 MB |
+| :ref:`sphx_glr_tutorials_quantization_quick_start_mnist.py` (``quantization_quick_start_mnist.py``) | 01:39.686 | 0.0 MB |
 +-----------------------------------------------------------------------------------------------------+-----------+--------+
-| :ref:`sphx_glr_tutorials_darts.py` (``darts.py``) | 01:51.710 | 0.0 MB |
+| :ref:`sphx_glr_tutorials_darts.py` (``darts.py``) | 00:00.000 | 0.0 MB |
 +-----------------------------------------------------------------------------------------------------+-----------+--------+
 | :ref:`sphx_glr_tutorials_hello_nas.py` (``hello_nas.py``) | 00:00.000 | 0.0 MB |
 +-----------------------------------------------------------------------------------------------------+-----------+--------+
 | :ref:`sphx_glr_tutorials_nasbench_as_dataset.py` (``nasbench_as_dataset.py``) | 00:00.000 | 0.0 MB |
 +-----------------------------------------------------------------------------------------------------+-----------+--------+
+| :ref:`sphx_glr_tutorials_pruning_bert_glue.py` (``pruning_bert_glue.py``) | 00:00.000 | 0.0 MB |
++-----------------------------------------------------------------------------------------------------+-----------+--------+
 | :ref:`sphx_glr_tutorials_pruning_customize.py` (``pruning_customize.py``) | 00:00.000 | 0.0 MB |
 +-----------------------------------------------------------------------------------------------------+-----------+--------+
 | :ref:`sphx_glr_tutorials_pruning_quick_start_mnist.py` (``pruning_quick_start_mnist.py``) | 00:00.000 | 0.0 MB |
@@ -24,7 +26,5 @@ Computation times
 +-----------------------------------------------------------------------------------------------------+-----------+--------+
 | :ref:`sphx_glr_tutorials_quantization_customize.py` (``quantization_customize.py``) | 00:00.000 | 0.0 MB |
 +-----------------------------------------------------------------------------------------------------+-----------+--------+
-| :ref:`sphx_glr_tutorials_quantization_quick_start_mnist.py` (``quantization_quick_start_mnist.py``) | 00:00.000 | 0.0 MB |
-+-----------------------------------------------------------------------------------------------------+-----------+--------+
 | :ref:`sphx_glr_tutorials_quantization_speedup.py` (``quantization_speedup.py``) | 00:00.000 | 0.0 MB |
 +-----------------------------------------------------------------------------------------------------+-----------+--------+
@@ -217,7 +217,7 @@ def main(args):
     }]
     optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
-    quantizer = QAT_Quantizer(model, config_list, optimizer)
+    quantizer = QAT_Quantizer(model, config_list, optimizer, dummy_input)
     quantizer.compress()
     # Step6. Quantization Aware Training
...
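Note on the ``QAT_Quantizer`` change above: the quantizer is now also given a ``dummy_input`` so it can trace the model graph. A minimal sketch of the updated call follows; the model, config values, and input shape here are stand-ins for illustration, not the script's own definitions (which are elided above):

    import torch
    from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer

    # Stand-in model and config for illustration only; the real script defines
    # its own model, config_list and dummy_input elsewhere in the file.
    model = torch.nn.Sequential(
        torch.nn.Conv2d(1, 8, 3), torch.nn.ReLU(),
        torch.nn.Flatten(), torch.nn.Linear(8 * 26 * 26, 10))
    config_list = [{
        'quant_types': ['input', 'weight'],
        'quant_bits': {'input': 8, 'weight': 8},
        'op_types': ['Conv2d', 'Linear'],
    }]
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    # The dummy input only needs the right shape; it drives the model trace.
    dummy_input = torch.rand(1, 1, 28, 28)
    quantizer = QAT_Quantizer(model, config_list, optimizer, dummy_input)
    quantizer.compress()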
@@ -134,11 +134,11 @@ def main():
         'op_names': ['features.6', 'features.9', 'features.13', 'features.16', 'features.20', 'classifier.2', 'classifier.5']
     }]
-    quantizer = BNNQuantizer(model, configure_list)
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
+    quantizer = BNNQuantizer(model, configure_list, optimizer)
     model = quantizer.compress()
     print('=' * 10 + 'train' + '=' * 10)
-    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
     best_top1 = 0
     for epoch in range(400):
         print('# Epoch {} #'.format(epoch))
...
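The ``BNNQuantizer`` hunk follows the same pattern as the QAT fix: the quantizer now receives the training optimizer at construction time, so the optimizer must be created before the quantizer rather than after ``compress()``. A hedged sketch with an illustrative model and config (the example's VGG-style model and its ``op_names`` are elided above):

    import torch
    from nni.algorithms.compression.pytorch.quantization import BNNQuantizer

    # Illustrative model/config; the real example binarizes specific VGG layers
    # selected via 'op_names' rather than 'op_types'.
    model = torch.nn.Sequential(
        torch.nn.Linear(32, 64), torch.nn.ReLU(), torch.nn.Linear(64, 10))
    configure_list = [{
        'quant_bits': 1,
        'quant_types': ['weight'],
        'op_types': ['Linear'],
    }]

    # Create the optimizer first: the quantizer hooks into it so the binarized
    # weights are handled correctly during training.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    quantizer = BNNQuantizer(model, configure_list, optimizer)
    model = quantizer.compress()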
@@ -29,7 +29,7 @@ import torch
 import torch.nn.functional as F
 from torch.optim import SGD
-from scripts.compression_mnist_model import TorchModel, trainer, evaluator, device
+from nni_assets.compression.mnist_model import TorchModel, trainer, evaluator, device
 # define the model
 model = TorchModel().to(device)
...
@@ -43,7 +43,7 @@ Usage
 # But in fact ``ModelSpeedup`` is a relatively independent tool, so you can use it independently.
 import torch
-from scripts.compression_mnist_model import TorchModel, device
+from nni_assets.compression.mnist_model import TorchModel, device
 model = TorchModel().to(device)
 # masks = {layer_name: {'weight': weight_mask, 'bias': bias_mask}}
...
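Since the hunk above notes that ``ModelSpeedup`` is an independent tool driven by a ``{layer_name: {'weight': weight_mask, 'bias': bias_mask}}`` dict, here is a rough sketch of that standalone usage. The toy model and hand-built mask are assumptions for illustration, not the tutorial's ``TorchModel`` or a pruner's real output:

    import torch
    from nni.compression.pytorch import ModelSpeedup

    # Toy two-layer model; module names in nn.Sequential are '0', '1', '2'.
    model = torch.nn.Sequential(
        torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 4))

    # masks = {layer_name: {'weight': weight_mask, 'bias': bias_mask}}
    weight_mask = torch.ones(16, 8)
    weight_mask[:4] = 0.            # pretend the first 4 output channels were pruned
    bias_mask = torch.ones(16)
    bias_mask[:4] = 0.
    masks = {'0': {'weight': weight_mask, 'bias': bias_mask}}

    # Replace masked layers with physically smaller ones, propagating the
    # pruned channels through the following layers.
    dummy_input = torch.rand(2, 8)
    ModelSpeedup(model, dummy_input, masks).speedup_model()
    print(model)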
@@ -24,7 +24,7 @@ import torch
 import torch.nn.functional as F
 from torch.optim import SGD
-from scripts.compression_mnist_model import TorchModel, trainer, evaluator, device, test_trt
+from nni_assets.compression.mnist_model import TorchModel, trainer, evaluator, device, test_trt
 # define the model
 model = TorchModel().to(device)
...
@@ -64,7 +64,7 @@ Usage
 import torch
 import torch.nn.functional as F
 from torch.optim import SGD
-from scripts.compression_mnist_model import TorchModel, device, trainer, evaluator, test_trt
+from nni_assets.compression.mnist_model import TorchModel, device, trainer, evaluator, test_trt
 config_list = [{
     'quant_types': ['input', 'weight'],
...
@@ -152,8 +152,12 @@ class ChannelDependency(Dependency):
         parent_layers = []
         queue = []
         queue.append(node)
+        visited_set = set()
         while queue:
             curnode = queue.pop(0)
+            if curnode in visited_set:
+                continue
+            visited_set.add(curnode)
             if curnode.op_type in self.target_types:
                 # find the first met conv
                 parent_layers.append(curnode.name)
@@ -164,6 +168,8 @@ class ChannelDependency(Dependency):
             parents = self.graph.find_predecessors(curnode.unique_name)
             parents = [self.graph.name_to_node[name] for name in parents]
             for parent in parents:
+                if parent in visited_set:
+                    continue
                 queue.append(parent)
         return parent_layers
......
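The ``visited_set`` added above keeps the breadth-first search from re-enqueueing nodes it has already processed, which matters when the traced graph reaches the same ancestor through more than one path (e.g. residual connections). The pattern in isolation, as a generic sketch rather than NNI's actual graph classes:

    from collections import deque

    def find_parent_layers(start, is_target, predecessors):
        """BFS over predecessors: collect the first matching ancestor on each
        path, and never process the same node twice."""
        parent_layers = []
        queue = deque([start])
        visited = set()
        while queue:
            node = queue.popleft()
            if node in visited:
                continue
            visited.add(node)
            if is_target(node):
                parent_layers.append(node)
                continue  # stop walking past the first match on this path
            for parent in predecessors(node):
                if parent not in visited:
                    queue.append(parent)
        return parent_layers

Without the visited set, a diamond-shaped dependency (two paths converging on one ancestor) puts that ancestor on the queue once per path, and a cycle would never terminate.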
@@ -56,7 +56,7 @@ def rand_like_with_shape(shape, ori_t):
     higher_bound = torch.max(ori_t)
     if dtype in [torch.uint8, torch.int16, torch.short, torch.int16, torch.long, torch.bool]:
-        return torch.randint(lower_bound, higher_bound+1, shape, dtype=dtype, device=device)
+        return torch.randint(lower_bound.long(), higher_bound.long() + 1, shape, dtype=dtype, device=device)
     else:
        return torch.rand(shape, dtype=dtype, device=device, requires_grad=require_grad)
......
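The ``rand_like_with_shape`` fix above widens the tensor's min/max to 64-bit values before the ``+ 1``, presumably to avoid overflow or dtype issues with narrow integer types such as ``uint8`` and ``bool``. The same idea expressed with plain Python ints, as an illustrative helper rather than NNI's code:

    import torch

    def rand_int_like_with_shape(shape, ori_t):
        """Random integer tensor of `shape` whose values stay inside ori_t's
        observed [min, max] range (illustrative sketch, not NNI's helper)."""
        low = int(ori_t.min())          # 0-dim tensor -> plain Python int
        high = int(ori_t.max()) + 1     # randint's upper bound is exclusive
        return torch.randint(low, high, shape, dtype=ori_t.dtype, device=ori_t.device)

    # Example: mimic the value range of a uint8 tensor with a new shape.
    t = torch.tensor([3, 250, 17], dtype=torch.uint8)
    print(rand_int_like_with_shape((2, 4), t))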
 from pathlib import Path
-root_path = Path(__file__).parent.parent
 # define the model
 import torch
 from torch import nn
@@ -38,13 +36,13 @@ device = torch.device("cuda" if use_cuda else "cpu")
 from torchvision import datasets, transforms
 train_loader = torch.utils.data.DataLoader(
-    datasets.MNIST(root_path / 'data', train=True, download=True, transform=transforms.Compose([
+    datasets.MNIST('data', train=True, download=True, transform=transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.1307,), (0.3081,))
     ])), batch_size=128, shuffle=True)
 test_loader = torch.utils.data.DataLoader(
-    datasets.MNIST(root_path / 'data', train=False, transform=transforms.Compose([
+    datasets.MNIST('data', train=False, transform=transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.1307,), (0.3081,))
     ])), batch_size=1000, shuffle=True)
...