"""
Quantization Quickstart
=======================

Quantization reduces model size and speeds up inference by lowering the number of bits used to represent weights or activations.

NNI supports both post-training quantization algorithms and quantization-aware training algorithms.
Here we use `QAT_Quantizer` as an example to show how to quantize a model with NNI.
"""

# %%
# Preparation
# -----------
#
# In this tutorial, we use a simple model and pre-train it on the MNIST dataset.
# If you are familiar with defining a model and training it in PyTorch, you can skip directly to `Quantizing Model`_.

import torch
import torch.nn.functional as F
from torch.optim import SGD

from scripts.compression_mnist_model import TorchModel, trainer, evaluator, device, test_trt

# define the model
model = TorchModel().to(device)

# define the optimizer and criterion for pre-training
optimizer = SGD(model.parameters(), 1e-2)
criterion = F.nll_loss

# pre-train and evaluate the model on MNIST dataset
for epoch in range(3):
    trainer(model, optimizer, criterion)
    evaluator(model)

# %%
# Quantizing Model
# ----------------
#
# Initialize a ``config_list``.
# For details about how to write a ``config_list``, please refer to :doc:`compression config specification <../compression/compression_config_list>`.

config_list = [{
    # quantize both the input and weight of all Conv2d layers to 8 bits
    'quant_types': ['input', 'weight'],
    'quant_bits': {'input': 8, 'weight': 8},
    'op_types': ['Conv2d']
}, {
    # quantize the output of all ReLU layers to 8 bits
    'quant_types': ['output'],
    'quant_bits': {'output': 8},
    'op_types': ['ReLU']
}, {
    # quantize the input and weight of the layers named fc1 and fc2 to 8 bits
    'quant_types': ['input', 'weight'],
    'quant_bits': {'input': 8, 'weight': 8},
    'op_names': ['fc1', 'fc2']
}]

# %%
# Finetune the model with QAT.
from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer

# the dummy input is used by the quantizer to trace the model graph
dummy_input = torch.rand(32, 1, 28, 28).to(device)
quantizer = QAT_Quantizer(model, config_list, optimizer, dummy_input)
quantizer.compress()

# %%
# The model has now been wrapped, and the quantization targets (the ``quant_types`` setting in ``config_list``)
# will be quantized and dequantized in the wrapped layers to simulate quantization.
# QAT is a quantization-aware training algorithm, so it updates the scale and zero point during training.
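
# as a quick sanity check, you can print the model to see the layers selected in
# config_list wrapped by NNI's quantization wrapper modules
print(model)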

for epoch in range(3):
    trainer(model, optimizer, criterion)
    evaluator(model)

# %%
# Export the model and get the calibration config, which records the quantization parameters and is used to build the TensorRT engine below.
model_path = "./log/mnist_model.pth"
calibration_path = "./log/mnist_calibration.pth"
calibration_config = quantizer.export_model(model_path, calibration_path)

print("calibration_config: ", calibration_config)

# %%
# Build a TensorRT engine to achieve a real speedup. For more information about speedup, please refer to :doc:`quantization_speedup`.

from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT
input_shape = (32, 1, 28, 28)
engine = ModelSpeedupTensorRT(model, input_shape, config=calibration_config, batchsize=32)
engine.compress()
test_trt(engine)