"git@developer.sourcefind.cn:OpenDAS/torch-sparce.git" did not exist on "3512164a978bbb5ff6256fbdeb70625072784c63"
Commit fa24368f authored by Supriya Rao, committed by Facebook GitHub Bot

Update callsites in mobile-vision

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/124

Update callsites from torch.quantization to torch.ao.quantization

Reviewed By: z-a-f, jerryzh168

Differential Revision: D31286125

fbshipit-source-id: ef24ca87d8db398c65bb5b89f035afe0423a5685
parent 9dc1600b
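
The change is mechanical: each torch.quantization callsite becomes torch.ao.quantization, the namespace the quantization code moved to in PyTorch 1.10. Files that must keep importing on older PyTorch gate the import on the installed version instead. Below is a minimal, standalone sketch of that gating pattern as it appears in the hunks that follow (it assumes only that torch.__version__ starts with a numeric "major.minor" prefix):

    from typing import Tuple

    import torch

    # Keep only the numeric "major.minor" prefix; suffixes like "1.10.0a0+git..."
    # would break int() if more fields were parsed.
    TORCH_VERSION: Tuple[int, ...] = tuple(
        int(x) for x in torch.__version__.split(".")[:2]
    )

    if TORCH_VERSION >= (1, 10):
        # PyTorch >= 1.10: quantization lives under the torch.ao namespace.
        from torch.ao.quantization.quantize_fx import convert_fx
    else:
        # Older PyTorch: fall back to the original namespace.
        from torch.quantization.quantize_fx import convert_fx

Comparing integer tuples rather than raw version strings avoids the classic pitfall that "1.9" sorts after "1.10" lexicographically.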
@@ -26,7 +26,7 @@ import logging
 import os
 import sys
 from abc import ABC, abstractmethod
-from typing import Callable, Dict, NamedTuple, Optional, Union
+from typing import Callable, Dict, NamedTuple, Optional, Union, Tuple

 if sys.version_info >= (3, 8):
     from typing import final
@@ -39,7 +39,6 @@ else:
 import torch
 import torch.nn as nn
-import torch.quantization.quantize_fx
 from d2go.modeling.quantization import post_training_quantize
 from detectron2.utils.file_io import PathManager
 from mobile_cv.arch.utils import fuse_utils
@@ -52,6 +51,13 @@ from mobile_cv.predictor.builtin_functions import (
     NaiveRunFunc,
 )

+TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
+if TORCH_VERSION >= (1, 10):
+    from torch.ao.quantization import convert
+    from torch.ao.quantization.quantize_fx import convert_fx
+else:
+    from torch.quantization import convert
+    from torch.quantization.quantize_fx import convert_fx

 logger = logging.getLogger(__name__)
@@ -108,13 +114,13 @@ def convert_predictor(
     if cfg.QUANTIZATION.EAGER_MODE:
         # TODO(T93870278): move this logic to prepare_for_quant_convert
-        pytorch_model = torch.quantization.convert(pytorch_model, inplace=False)
+        pytorch_model = convert(pytorch_model, inplace=False)
     else:  # FX graph mode quantization
         if hasattr(pytorch_model, "prepare_for_quant_convert"):
             pytorch_model = pytorch_model.prepare_for_quant_convert(cfg)
         else:
             # TODO(T93870381): move this to a default function
-            pytorch_model = torch.quantization.quantize_fx.convert_fx(pytorch_model)
+            pytorch_model = convert_fx(pytorch_model)

         logger.info("Quantized Model:\n{}".format(pytorch_model))
     else:
...
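
For context, the two branches in convert_predictor above map onto PyTorch's two quantization workflows: eager mode (module swapping driven by a qconfig attribute) and FX graph mode (trace-and-rewrite). The following is a condensed, hypothetical sketch of that flow after the rename, not d2go's actual API; quantize_for_deploy and the placeholder training step are illustrative, and the torch.ao imports assume PyTorch >= 1.10:

    import torch
    from torch.ao.quantization import get_default_qat_qconfig, prepare_qat, convert
    from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx

    def quantize_for_deploy(model: torch.nn.Module, eager_mode: bool,
                            backend: str = "qnnpack") -> torch.nn.Module:
        model.train()  # both prepare paths expect a model in training mode
        qconfig = get_default_qat_qconfig(backend)
        if eager_mode:
            # Eager mode: attach the qconfig to the module, then prepare/convert
            # rewrite the module tree (mirrors the EAGER_MODE branch above).
            torch.backends.quantized.engine = backend
            model.qconfig = qconfig
            prepare_qat(model, inplace=True)
            # ... QAT fine-tuning would run here ...
            return convert(model.eval(), inplace=False)
        else:
            # FX graph mode: the model is traced and rewritten functionally
            # (mirrors the convert_fx branch above).
            model = prepare_qat_fx(model, {"": qconfig})
            # ... QAT fine-tuning would run here ...
            return convert_fx(model.eval())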
@@ -24,7 +24,7 @@ from mobile_cv.arch.utils.quantize_utils import (
     QuantWrapper,
 )
 from mobile_cv.predictor.api import FuncInfo
-from torch.quantization.quantize_fx import prepare_fx, prepare_qat_fx, convert_fx
+from torch.ao.quantization.quantize_fx import prepare_fx, prepare_qat_fx, convert_fx

 logger = logging.getLogger(__name__)
@@ -235,9 +235,9 @@ def default_rcnn_prepare_for_quant(self, cfg):
     model = self
     torch.backends.quantized.engine = cfg.QUANTIZATION.BACKEND
     model.qconfig = (
-        torch.quantization.get_default_qat_qconfig(cfg.QUANTIZATION.BACKEND)
+        torch.ao.quantization.get_default_qat_qconfig(cfg.QUANTIZATION.BACKEND)
         if model.training
-        else torch.quantization.get_default_qconfig(cfg.QUANTIZATION.BACKEND)
+        else torch.ao.quantization.get_default_qconfig(cfg.QUANTIZATION.BACKEND)
     )
     if (
         hasattr(model, "roi_heads")
...
@@ -6,13 +6,18 @@ import contextlib
 import copy
 import inspect
 import logging
+from typing import Tuple

 import torch
-import torch.quantization.quantize_fx
 from detectron2.checkpoint import DetectionCheckpointer
 from mobile_cv.arch.utils import fuse_utils
 from mobile_cv.common.misc.iter_utils import recursive_iterate

+TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
+if TORCH_VERSION >= (1, 10):
+    from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
+else:
+    from torch.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx

 logger = logging.getLogger(__name__)
@@ -148,7 +153,7 @@ def _cast_detection_model(model, device):
 def add_d2_quant_mapping(mappings):
     """HACK: Add d2 specific module mapping for eager model quantization"""
-    import torch.quantization.quantization_mappings as qm
+    import torch.ao.quantization.quantization_mappings as qm

     for k, v in mappings.items():
         if k not in qm.get_default_static_quant_module_mappings():
@@ -218,7 +223,7 @@ def default_prepare_for_quant(cfg, model):
     - QAT/PTQ can be determined by model.training.
     - Currently the input model can be changed inplace since we won't re-use the
       input model.
-    - Currently this API doesn't include the final torch.quantization.prepare(_qat)
+    - Currently this API doesn't include the final torch.ao.quantization.prepare(_qat)
       call since existing usecases don't have further steps after it.

     Args:
@@ -229,9 +234,9 @@ def default_prepare_for_quant(cfg, model):
         nn.Module: a ready model for QAT training or PTQ calibration
     """
     qconfig = (
-        torch.quantization.get_default_qat_qconfig(cfg.QUANTIZATION.BACKEND)
+        torch.ao.quantization.get_default_qat_qconfig(cfg.QUANTIZATION.BACKEND)
         if model.training
-        else torch.quantization.get_default_qconfig(cfg.QUANTIZATION.BACKEND)
+        else torch.ao.quantization.get_default_qconfig(cfg.QUANTIZATION.BACKEND)
     )

     if cfg.QUANTIZATION.EAGER_MODE:
@@ -239,14 +244,14 @@ def default_prepare_for_quant(cfg, model):
         torch.backends.quantized.engine = cfg.QUANTIZATION.BACKEND
         model.qconfig = qconfig
-        # TODO(future diff): move the torch.quantization.prepare(...) call
+        # TODO(future diff): move the torch.ao.quantization.prepare(...) call
         # here, to be consistent with the FX branch
     else:  # FX graph mode quantization
         qconfig_dict = {"": qconfig}
         if model.training:
-            model = torch.quantization.quantize_fx.prepare_qat_fx(model, qconfig_dict)
+            model = prepare_qat_fx(model, qconfig_dict)
         else:
-            model = torch.quantization.quantize_fx.prepare_fx(model, qconfig_dict)
+            model = prepare_fx(model, qconfig_dict)

     logger.info("Setup the model with qconfig:\n{}".format(qconfig))
@@ -254,7 +259,7 @@ def default_prepare_for_quant(cfg, model):
 def default_prepare_for_quant_convert(cfg, model):
-    return torch.quantization.quantize_fx.convert_fx(model)
+    return convert_fx(model)

 @mock_quantization_type
@@ -273,7 +278,7 @@ def post_training_quantize(cfg, model, data_loader):
         model = default_prepare_for_quant(cfg, model)
     if cfg.QUANTIZATION.EAGER_MODE:
-        torch.quantization.prepare(model, inplace=True)
+        torch.ao.quantization.prepare(model, inplace=True)
     logger.info("Prepared the PTQ model for calibration:\n{}".format(model))

     # Option for forcing running calibration on GPU, works only when the model supports
@@ -329,7 +334,7 @@ def setup_qat_model(cfg, model, enable_fake_quant=False, enable_observer=False):
         model = default_prepare_for_quant(cfg, model)
         # TODO(future diff): move this into prepare_for_quant to match FX branch
-        torch.quantization.prepare_qat(model, inplace=True)
+        torch.ao.quantization.prepare_qat(model, inplace=True)
     else:  # FX graph mode quantization
         if hasattr(model, "prepare_for_quant"):
             model = model.prepare_for_quant(cfg)
@@ -342,10 +347,10 @@ def setup_qat_model(cfg, model, enable_fake_quant=False, enable_observer=False):
     if not enable_fake_quant:
         logger.info("Disabling fake quant ...")
-        model.apply(torch.quantization.disable_fake_quant)
+        model.apply(torch.ao.quantization.disable_fake_quant)
     if not enable_observer:
         logger.info("Disabling observer ...")
-        model.apply(torch.quantization.disable_observer)
+        model.apply(torch.ao.quantization.disable_observer)
     # fuse_model and prepare_qat may change the state_dict of the model; keep a map
     # from the original model to the QAT keys in order to load weights from a non-QAT model.
...
@@ -13,15 +13,15 @@ from mobile_cv.arch.quantization.observer import update_stat as observer_update_
 from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.utilities import rank_zero_info
-from torch.quantization import (  # @manual
+from torch.ao.quantization import (  # @manual
     QConfig,
     QConfigDynamic,
     QuantType,
     get_default_qat_qconfig,
     get_default_qconfig,
 )
-from torch.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
-from torch.quantization.utils import get_quant_type
+from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
+from torch.ao.quantization.utils import get_quant_type

 QConfigDicts = Dict[str, Dict[str, Union[QConfig, QConfigDynamic]]]
@@ -355,12 +355,12 @@ class QuantizationAwareTraining(Callback, QuantizationMixin):
             # Enabled by default, so the assumption for > 0 is that the
             # user wants it disabled then enabled.
             ModelTransform(
-                fn=torch.quantization.disable_fake_quant,
+                fn=torch.ao.quantization.disable_fake_quant,
                 step=0,
                 message="Disable fake quant",
             ),
             ModelTransform(
-                fn=torch.quantization.enable_fake_quant,
+                fn=torch.ao.quantization.enable_fake_quant,
                 step=start_step,
                 message="Enable fake quant to start QAT",
             ),
@@ -371,12 +371,12 @@ class QuantizationAwareTraining(Callback, QuantizationMixin):
             # See comment for start_step above.
             [
                 ModelTransform(
-                    fn=torch.quantization.disable_observer,
+                    fn=torch.ao.quantization.disable_observer,
                     step=0,
                     message="Disable observer",
                 ),
                 ModelTransform(
-                    fn=torch.quantization.enable_observer,
+                    fn=torch.ao.quantization.enable_observer,
                     step=start_observer,
                     message="Start observer",
                 ),
@@ -385,7 +385,7 @@ class QuantizationAwareTraining(Callback, QuantizationMixin):
         if end_observer is not None:
             self.transforms.append(
                 ModelTransform(
-                    fn=torch.quantization.disable_observer,
+                    fn=torch.ao.quantization.disable_observer,
                     step=end_observer,
                     message="End observer",
                 )
...
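
The ModelTransform entries above only schedule calls to stock torch helpers; each helper is a plain function over a single submodule, which is why it can be handed to Module.apply (or to ModelTransform's fn=). A small runnable sketch with a toy model (the model itself is illustrative; the toggle calls are the same ones the callback above and the runner below schedule):

    import torch
    from torch.ao.quantization import (
        get_default_qat_qconfig,
        prepare_qat,
        disable_fake_quant,
        enable_fake_quant,
        disable_observer,
        enable_observer,
    )

    # Toy eager-mode model; prepare_qat swaps supported modules for QAT
    # variants that carry FakeQuantize/observer submodules.
    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
    model.train()
    model.qconfig = get_default_qat_qconfig("qnnpack")
    qat_model = prepare_qat(model)

    # Module.apply visits every submodule, so these toggle all fake-quant and
    # observer modules at once -- exactly what the step-scheduled transforms do.
    qat_model.apply(disable_fake_quant)  # train without quantization noise
    qat_model.apply(enable_observer)     # but keep collecting activation ranges
    qat_model.apply(enable_fake_quant)   # later: turn QAT on
    qat_model.apply(disable_observer)    # later: freeze the collected ranges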
@@ -630,7 +630,7 @@ class Detectron2GoRunner(BaseRunner):
                     trainer.iter
                 )
             )
-            trainer.model.apply(torch.quantization.enable_fake_quant)
+            trainer.model.apply(torch.ao.quantization.enable_fake_quant)
             applied["enable_fake_quant"] = True

         if cfg.QUANTIZATION.QAT.BATCH_SIZE_FACTOR != 1.0:
@@ -661,7 +661,7 @@ class Detectron2GoRunner(BaseRunner):
             and trainer.iter < cfg.QUANTIZATION.QAT.DISABLE_OBSERVER_ITER
         ):
             logger.info("[QAT] enable observer, iter = {}".format(trainer.iter))
-            trainer.model.apply(torch.quantization.enable_observer)
+            trainer.model.apply(torch.ao.quantization.enable_observer)
             applied["enable_observer"] = True

         if (
@@ -673,7 +673,7 @@ class Detectron2GoRunner(BaseRunner):
                     trainer.iter
                 )
             )
-            trainer.model.apply(torch.quantization.disable_observer)
+            trainer.model.apply(torch.ao.quantization.disable_observer)
             applied["disable_observer"] = True

         if (
...
@@ -6,7 +6,7 @@ import torch
 from d2go.utils.testing.data_loader_helper import create_local_dataset
 from detectron2.modeling import META_ARCH_REGISTRY
 from detectron2.structures import Boxes, ImageList, Instances
-from torch.quantization.quantize_fx import prepare_qat_fx, convert_fx
+from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx

 @META_ARCH_REGISTRY.register()
@@ -54,7 +54,7 @@ class DetMetaArchForTest(torch.nn.Module):
     def prepare_for_quant(self, cfg):
         self.avgpool = prepare_qat_fx(
             self.avgpool,
-            {"": torch.quantization.get_default_qat_qconfig()},
+            {"": torch.ao.quantization.get_default_qat_qconfig()},
         )
         return self
...
@@ -20,11 +20,11 @@ from d2go.utils.testing.helper import tempdir
 from d2go.utils.testing.lightning_test_module import TestModule
 from pytorch_lightning import Trainer, seed_everything
 from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
-from torch.quantization import (  # @manual
+from torch.ao.quantization import (  # @manual
     default_dynamic_qconfig,
     get_default_qconfig,
 )
-from torch.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
+from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx

 class TestUtilities(unittest.TestCase):
...
@@ -21,7 +21,7 @@ from detectron2.modeling import META_ARCH_REGISTRY
 from detectron2.utils.events import EventStorage
 from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
 from torch import Tensor
-from torch.quantization.quantize_fx import prepare_qat_fx, convert_fx
+from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx

 class TestLightningTask(unittest.TestCase):
@@ -175,7 +175,7 @@ class TestLightningTask(unittest.TestCase):
     def prepare_for_quant(self, cfg):
         self.avgpool = prepare_qat_fx(
             self.avgpool,
-            {"": torch.quantization.get_default_qat_qconfig()},
+            {"": torch.ao.quantization.get_default_qat_qconfig()},
             self.custom_config_dict,
         )
         return self
...
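
The test meta-archs above quantize a single submodule rather than the whole model, and the pattern generalizes. A hedged sketch of that idea (HeadOnlyQuant is hypothetical; the qconfig-dict shape follows the calls above, the hook names mirror d2go's prepare_for_quant / prepare_for_quant_convert convention, and on PyTorch < 1.10 the same functions live under torch.quantization.quantize_fx):

    import torch
    from torch.ao.quantization import get_default_qat_qconfig
    from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx

    class HeadOnlyQuant(torch.nn.Module):
        """Hypothetical module that FX-quantizes only its avgpool submodule,
        leaving the rest of the network in floating point."""

        def __init__(self):
            super().__init__()
            self.backbone = torch.nn.Conv2d(3, 8, 3)  # stays fp32
            self.avgpool = torch.nn.AdaptiveAvgPool2d(1)

        def prepare_for_quant(self, cfg=None):
            # {"": qconfig} applies one global qconfig inside the traced submodule.
            self.avgpool = prepare_qat_fx(
                self.avgpool,
                {"": get_default_qat_qconfig("qnnpack")},
            )
            return self

        def prepare_for_quant_convert(self, cfg=None):
            self.avgpool = convert_fx(self.avgpool)
            return self

        def forward(self, x):
            return self.avgpool(self.backbone(x))

Scoping quantization to one submodule keeps FX tracing away from parts of the model that are not symbolically traceable, which is the usual reason detection meta-archs quantize heads or pooling layers in isolation.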