"""
Quantization Quickstart
=======================

Here is a four-minute video to get you started with model quantization.

..  youtube:: MSfV7AyfiA4
    :align: center

Quantization reduces model size and speeds up inference by lowering the number of bits required to represent weights or activations.

NNI supports both post-training quantization algorithms and quantization-aware training algorithms.
Here we use `QAT_Quantizer` as an example to show how to apply quantization in NNI.
"""

# %%
# Preparation
# -----------
#
# In this tutorial, we use a simple model and pre-train it on the MNIST dataset.
# If you are familiar with defining a model and training it in PyTorch, you can skip directly to `Quantizing Model`_.

import torch
import torch.nn.functional as F
from torch.optim import SGD

from nni_assets.compression.mnist_model import TorchModel, trainer, evaluator, device, test_trt

# define the model
model = TorchModel().to(device)

# define the optimizer and criterion for pre-training

optimizer = SGD(model.parameters(), 1e-2)
criterion = F.nll_loss

# pre-train and evaluate the model on MNIST dataset
for epoch in range(3):
    trainer(model, optimizer, criterion)
    evaluator(model)

# %%
# Quantizing Model
# ----------------
#
# Initialize a `config_list`.
# For details on how to write a ``config_list``, please refer to :doc:`compression config specification <../compression/compression_config_list>`.

config_list = [{
    'quant_types': ['input', 'weight'],
    'quant_bits': {'input': 8, 'weight': 8},
    'op_types': ['Conv2d']
}, {
    'quant_types': ['output'],
    'quant_bits': {'output': 8},
    'op_types': ['ReLU']
}, {
    'quant_types': ['input', 'weight'],
    'quant_bits': {'input': 8, 'weight': 8},
    'op_names': ['fc1', 'fc2']
}]
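
# %%
# As a side note (an illustrative sketch only, not applied in this tutorial), the same
# ``config_list`` keys can express mixed bit-widths per layer. The 4-bit weight setting
# below is an assumption for illustration; backends such as TensorRT typically expect
# 8 bits for deployment.

example_mixed_config = [{
    'quant_types': ['weight'],
    'quant_bits': {'weight': 4},   # narrower weights for the first fully-connected layer (illustrative)
    'op_names': ['fc1']
}, {
    'quant_types': ['input', 'weight'],
    'quant_bits': {'input': 8, 'weight': 8},
    'op_types': ['Conv2d']
}]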

# %%
# Fine-tune the model with quantization-aware training (QAT).
from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer
dummy_input = torch.rand(32, 1, 28, 28).to(device)
quantizer = QAT_Quantizer(model, config_list, optimizer, dummy_input)
quantizer.compress()

# %%
# The model has now been wrapped, and the quantization targets (the 'quant_types' settings in `config_list`)
# will be quantized and dequantized in the wrapped layers to simulate quantization.
# QAT is a training-aware quantizer; it updates the scale and zero point during training.
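#
# As a quick sanity check (optional; plain PyTorch, not part of the original flow),
# printing the model should reveal the quantizer's wrapper modules around the
# configured layers.

print(model)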

for epoch in range(3):
    trainer(model, optimizer, criterion)
    evaluator(model)

# %%
# Export the model and get the ``calibration_config``.
model_path = "./log/mnist_model.pth"
calibration_path = "./log/mnist_calibration.pth"
calibration_config = quantizer.export_model(model_path, calibration_path)

print("calibration_config: ", calibration_config)

# %%
# Build a TensorRT engine to achieve real speedup. For more information about speedup, please refer to :doc:`quantization_speedup`.

from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT
input_shape = (32, 1, 28, 28)
engine = ModelSpeedupTensorRT(model, input_shape, config=calibration_config, batchsize=32)
engine.compress()
test_trt(engine)