"src/include/vscode:/vscode.git/clone" did not exist on "29b354e072c302a5bb59fd637d63d51a94721438"
quantization.py 4.59 KB
Newer Older
Sugon_ldc's avatar
Sugon_ldc committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from tqdm import tqdm
import torch
import contextlib
import time
import logging

from pytorch_quantization import quant_modules
from pytorch_quantization import nn as quant_nn
from pytorch_quantization import calib
from pytorch_quantization.tensor_quant import QuantDescriptor
from . import logger as log
from .utils import calc_ips
import dllogger

# Module-level aliases for quant_modules.initialize / quant_modules.deactivate,
# so the rest of this module (and its importers) can call them by short name.
initialize = quant_modules.initialize
deactivate = quant_modules.deactivate

# Metadata attached to the calibration metrics registered in collect_stats():
# throughput metrics use IPS_METADATA, timing metrics use TIME_METADATA.
IPS_METADATA = {"unit": "img/s", "format": ":.2f"}
TIME_METADATA = {"unit": "s", "format": ":.5f"}


def select_default_calib_method(calib_method='histogram'):
    """Install ``calib_method`` as the default input calibration method.

    One ``QuantDescriptor`` built with ``calib_method`` is set as the default
    input quantization descriptor of ``QuantConv1d``, ``QuantConv2d``,
    ``QuantLinear`` and ``QuantAdaptiveAvgPool2d``.
    """
    input_descriptor = QuantDescriptor(calib_method=calib_method)
    quantized_layers = (
        quant_nn.QuantConv1d,
        quant_nn.QuantConv2d,
        quant_nn.QuantLinear,
        quant_nn.QuantAdaptiveAvgPool2d,
    )
    for layer_cls in quantized_layers:
        layer_cls.set_default_quant_desc_input(input_descriptor)


def quantization_setup(calib_method='histogram'):
    """Set the default calibration method, then call ``initialize()``.

    ``calib_method`` (``'histogram'`` unless overridden) is forwarded to
    ``select_default_calib_method`` before ``initialize()`` runs.
    """
    select_default_calib_method(calib_method=calib_method)
    initialize()


def disable_calibration(model):
    """Switch every ``TensorQuantizer`` in ``model`` out of calibration mode.

    Quantizers that own a calibrator get quantization enabled and calibration
    disabled; quantizers without a calibrator are simply enabled. Should
    always be run before running inference.
    """
    for _, quantizer in model.named_modules():
        if not isinstance(quantizer, quant_nn.TensorQuantizer):
            continue
        if quantizer._calibrator is None:
            quantizer.enable()
            continue
        quantizer.enable_quant()
        quantizer.disable_calib()


def collect_stats(model, data_loader, logger, num_batches):
    """Feed calibration data through the network and collect statistics.

    Quantizers that own a calibrator are put into calibration mode
    (quantization off, calibration on); quantizers without a calibrator are
    disabled. Once the forward passes are done, ``disable_calibration``
    switches the quantizers back.

    Args:
        model: Network containing ``quant_nn.TensorQuantizer`` modules. It is
            called on ``image.cuda()`` for every batch.
        data_loader: Iterable yielding ``(image, _)`` pairs; only the first
            element of each pair is used.
        logger: Optional metrics logger. When ``None``, metric registration
            and reporting are skipped entirely.
        num_batches: Number of batches to feed through the model. One batch
            is still processed when this is less than one, provided the
            loader yields any.

    Side effects:
        Calls ``logging.disable(logging.WARNING)`` before returning and does
        not undo it; the caller must re-enable logging (``calibrate`` does).
    """
    if logger is not None:
        for metric_name, metadata in (
            ("calib.total_ips", IPS_METADATA),
            ("calib.data_time", TIME_METADATA),
            ("calib.compute_latency", TIME_METADATA),
        ):
            logger.register_metric(
                metric_name,
                log.PERF_METER(),
                verbosity=dllogger.Verbosity.DEFAULT,
                metadata=metadata,
            )

    data_iter = enumerate(data_loader)
    if logger is not None:
        data_iter = logger.iteration_generator_wrapper(data_iter, mode='calib')

    # Enable calibrators
    for _, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.disable_quant()
                module.enable_calib()
            else:
                module.disable()

    end = time.time()

    if logger is not None:
        logger.start_calibration()

    for i, (image, _) in data_iter:
        bs = image.size(0)
        # Time spent waiting for the batch since the previous iteration ended.
        data_time = time.time() - end

        model(image.cuda())

        it_time = time.time() - end

        if logger is not None:
            logger.log_metric("calib.total_ips", calc_ips(bs, it_time))
            logger.log_metric("calib.data_time", data_time)
            logger.log_metric("calib.compute_latency", it_time - data_time)

        # ``i`` is zero-based, so this batch is the (i + 1)-th one fed. The
        # previous ``i >= num_batches`` check let one extra batch through.
        if i + 1 >= num_batches:
            # NOTE(review): the reason for this pause is not evident from
            # this file; kept unchanged.
            time.sleep(5)
            break

        end = time.time()

    if logger is not None:
        logger.end_calibration()

    # Mute log records up to WARNING; the caller re-enables logging later.
    logging.disable(logging.WARNING)
    disable_calibration(model)


def compute_amax(model, **kwargs):
    """Load the collected calibration statistics into every quantizer.

    For each ``TensorQuantizer`` that owns a calibrator, ``load_calib_amax``
    is called: without arguments for a ``MaxCalibrator``, with ``**kwargs``
    for any other calibrator. The model is then moved to CUDA.
    """
    for _, quantizer in model.named_modules():
        if not isinstance(quantizer, quant_nn.TensorQuantizer):
            continue
        calibrator = quantizer._calibrator
        if calibrator is None:
            continue
        # MaxCalibrator is loaded without the extra keyword arguments.
        load_args = {} if isinstance(calibrator, calib.MaxCalibrator) else kwargs
        quantizer.load_calib_amax(**load_args)
    model.cuda()


def calibrate(model, train_loader, logger, calib_iter=1, percentile=99.99):
    """Calibrate the whole network: gather statistics, then compute quantization parameters.

    Args:
        model: Network to calibrate; it is put into eval mode first.
        train_loader: Data source passed to ``collect_stats``.
        logger: Optional metrics logger passed to ``collect_stats``.
        calib_iter: Forwarded to ``collect_stats`` as ``num_batches``.
        percentile: Forwarded to ``compute_amax`` together with
            ``method="percentile"``.
    """
    model.eval()

    try:
        with torch.no_grad():
            collect_stats(model, train_loader, logger, num_batches=calib_iter)
            compute_amax(model, method="percentile", percentile=percentile)
    finally:
        # collect_stats() mutes logging up to WARNING. Lift that even when
        # calibration fails, so logging is not left suppressed process-wide.
        logging.disable(logging.NOTSET)


@contextlib.contextmanager
def switch_on_quantization(do_quantization=True):
    """Context manager that activates quantization for the enclosed block.

    When ``do_quantization`` is true, ``initialize()`` runs on entry and
    ``deactivate()`` runs on exit, including when the body raises. When it is
    false, the body runs with nothing done on entry or exit.
    """
    if not do_quantization:
        yield
        return

    initialize()
    try:
        yield
    finally:
        deactivate()