import torch
from torch import nn


def _replace_relu(module: nn.Module) -> None:
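    """Recursively replace exact ``nn.ReLU``/``nn.ReLU6`` children with ``nn.ReLU(inplace=False)``."""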
    reassign = {}
    for name, mod in module.named_children():
        _replace_relu(mod)
        # Check for the exact type instead of using isinstance, because we only
        # want to replace modules of this exact type, not subclasses.
        if type(mod) is nn.ReLU or type(mod) is nn.ReLU6:
            reassign[name] = nn.ReLU(inplace=False)

    for key, value in reassign.items():
        module._modules[key] = value


def quantize_model(model: nn.Module, backend: str) -> None:
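    """Run post-training static quantization on ``model`` in place.

    ``model`` must define ``fuse_model()``; calibration runs a single random
    1x3x299x299 input through the model. ``backend`` must be one of the engines
    listed in ``torch.backends.quantized.supported_engines`` (e.g. "fbgemm" or
    "qnnpack").
    """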
    _dummy_input_data = torch.rand(1, 3, 299, 299)
    if backend not in torch.backends.quantized.supported_engines:
        raise RuntimeError(f"Quantized backend not supported: {backend}")
    torch.backends.quantized.engine = backend
    model.eval()
    # Make sure that weight qconfig matches that of the serialized models
    if backend == "fbgemm":
        model.qconfig = torch.ao.quantization.QConfig(  # type: ignore[assignment]
            activation=torch.ao.quantization.default_observer,
            weight=torch.ao.quantization.default_per_channel_weight_observer,
        )
    elif backend == "qnnpack":
        model.qconfig = torch.ao.quantization.QConfig(  # type: ignore[assignment]
            activation=torch.ao.quantization.default_observer,
            weight=torch.ao.quantization.default_weight_observer,
        )

    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    model.fuse_model()  # type: ignore[operator]
    torch.ao.quantization.prepare(model, inplace=True)
    model(_dummy_input_data)
    torch.ao.quantization.convert(model, inplace=True)

    return
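

if __name__ == "__main__":
    # Hedged usage sketch (not part of the original module): post-training static
    # quantization of a torchvision quantizable model. Assumes torchvision is
    # installed at a version whose quantizable resnet18 accepts ``weights=``,
    # and an x86 build where the "fbgemm" engine is available.
    from torchvision.models.quantization import resnet18

    float_model = resnet18(weights=None, quantize=False)
    _replace_relu(float_model)  # swap exact ReLU/ReLU6 children for fresh non-inplace ReLU
    quantize_model(float_model, backend="fbgemm")  # fuse, calibrate on a dummy input, convert in place
    print(float_model(torch.rand(1, 3, 224, 224)).shape)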