fx.py 1.09 KB
Newer Older
Jiaxu Zhu's avatar
Jiaxu Zhu committed
1
2
3
4
5
6
7
8
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved


from typing import Tuple

import torch
from mobile_cv.common.misc.oss_utils import fb_overwritable
Naveen Suda's avatar
Naveen Suda committed
9
from torch.ao.quantization.quantize import prepare, prepare_qat
Jiaxu Zhu's avatar
Jiaxu Zhu committed
10
11
12
13


# (major, minor) of the installed torch, e.g. "1.13.1+cu117" -> (1, 13).
TORCH_VERSION: Tuple[int, ...] = tuple(map(int, torch.__version__.split(".")[:2]))

# Pick the quantization entry points by version: 1.10 and older import them
# from the legacy ``torch.quantization`` namespace, while anything newer
# imports them from ``torch.ao.quantization``.
if TORCH_VERSION <= (1, 10):
    from torch.quantization.quantize import convert
    from torch.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
else:
    from torch.ao.quantization.quantize import convert
    from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx


@fb_overwritable()
def get_prepare_fx_fn(cfg, is_qat):
    """Return the FX graph-mode "prepare" function to use.

    Args:
        cfg: config object. This default implementation does not read it;
            presumably it is accepted so that an overriding implementation can
            (TODO confirm against the ``fb_overwritable`` replacement).
        is_qat: truthy to select quantization-aware training preparation.

    Returns:
        ``prepare_qat_fx`` when ``is_qat`` is truthy, otherwise ``prepare_fx``.
    """
    if is_qat:
        return prepare_qat_fx
    return prepare_fx


@fb_overwritable()
27
28
29
30
31
def get_convert_fn(cfg, example_inputs=None, qconfig_mapping=None, backend_config=None):
    """Return the function that converts a prepared model into a quantized one.

    Args:
        cfg: config object; only ``cfg.QUANTIZATION.EAGER_MODE`` is read here.
        example_inputs, qconfig_mapping, backend_config: not used by this
            default implementation. Presumably they are part of the signature
            so that an overriding implementation can use them (TODO confirm).

    Returns:
        Eager-mode ``convert`` when ``cfg.QUANTIZATION.EAGER_MODE`` is truthy,
        otherwise FX graph-mode ``convert_fx``.
    """
    return convert if cfg.QUANTIZATION.EAGER_MODE else convert_fx
Naveen Suda's avatar
Naveen Suda committed
32
33
34
35
36
37


@fb_overwritable()
def get_prepare_fn(cfg, is_qat):
    if cfg.QUANTIZATION.EAGER_MODE:
        return prepare_qat if is_qat else prepare