"git@developer.sourcefind.cn:change/sglang.git" did not exist on "df246e699d2a18873da2b2c47b432d07b17d8cca"
Commit fa24368f authored by Supriya Rao, committed by Facebook GitHub Bot

Update callsites in mobile-vision

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/124

Update callsites from torch.quantization to torch.ao.quantization

Reviewed By: z-a-f, jerryzh168

Differential Revision: D31286125

fbshipit-source-id: ef24ca87d8db398c65bb5b89f035afe0423a5685
parent 9dc1600b
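For context: the updated files gate the new namespace on the installed PyTorch version, so older releases still import from torch.quantization. A minimal standalone sketch of that pattern (the imported names below are just examples; each file in the diff pulls in whichever symbols it needs):

```python
# Sketch of the version-gated import used throughout this diff: prefer the
# torch.ao.quantization namespace on PyTorch >= 1.10, fall back otherwise.
from typing import Tuple

import torch

TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])

if TORCH_VERSION >= (1, 10):
    from torch.ao.quantization import convert, get_default_qconfig
    from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
else:
    from torch.quantization import convert, get_default_qconfig
    from torch.quantization.quantize_fx import convert_fx, prepare_fx
```

Call sites then use the bare names (convert, convert_fx, and so on), so the rest of the module does not depend on which namespace supplied them.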
@@ -26,7 +26,7 @@ import logging
 import os
 import sys
 from abc import ABC, abstractmethod
-from typing import Callable, Dict, NamedTuple, Optional, Union
+from typing import Callable, Dict, NamedTuple, Optional, Union, Tuple

 if sys.version_info >= (3, 8):
     from typing import final
@@ -39,7 +39,6 @@ else:
 import torch
 import torch.nn as nn
-import torch.quantization.quantize_fx
 from d2go.modeling.quantization import post_training_quantize
 from detectron2.utils.file_io import PathManager
 from mobile_cv.arch.utils import fuse_utils
@@ -52,6 +51,13 @@ from mobile_cv.predictor.builtin_functions import (
     NaiveRunFunc,
 )

+TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
+if TORCH_VERSION >= (1, 10):
+    from torch.ao.quantization import convert
+    from torch.ao.quantization.quantize_fx import convert_fx
+else:
+    from torch.quantization import convert
+    from torch.quantization.quantize_fx import convert_fx

 logger = logging.getLogger(__name__)
@@ -108,13 +114,13 @@ def convert_predictor(
         if cfg.QUANTIZATION.EAGER_MODE:
             # TODO(T93870278): move this logic to prepare_for_quant_convert
-            pytorch_model = torch.quantization.convert(pytorch_model, inplace=False)
+            pytorch_model = convert(pytorch_model, inplace=False)
         else:  # FX graph mode quantization
             if hasattr(pytorch_model, "prepare_for_quant_convert"):
                 pytorch_model = pytorch_model.prepare_for_quant_convert(cfg)
             else:
                 # TODO(T93870381): move this to a default function
-                pytorch_model = torch.quantization.quantize_fx.convert_fx(pytorch_model)
+                pytorch_model = convert_fx(pytorch_model)

         logger.info("Quantized Model:\n{}".format(pytorch_model))
     else:
......
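The convert_predictor hunk above only switches namespaces; the surrounding logic still chooses between eager-mode and FX graph-mode conversion. A rough sketch of those two paths under torch.ao.quantization (the toy model, backend string, and calibration input are illustrative, and the prepare_fx call uses the torch 1.10-era signature that this diff targets):

```python
# Minimal sketch of the two conversion paths selected by cfg.QUANTIZATION.EAGER_MODE.
import torch
import torch.nn as nn
from torch.ao.quantization import QuantWrapper, convert, get_default_qconfig, prepare
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

float_model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()).eval()
example = torch.randn(1, 3, 32, 32)
eager_mode = True  # stands in for cfg.QUANTIZATION.EAGER_MODE

if eager_mode:
    # Eager mode: wrap with quant/dequant stubs, attach a qconfig, observe, convert.
    model = QuantWrapper(float_model)
    model.qconfig = get_default_qconfig("fbgemm")
    prepare(model, inplace=True)
    model(example)                          # calibration pass fills the observers
    quantized = convert(model, inplace=False)
else:
    # FX graph mode: the qconfig is passed as a dict keyed by module name.
    qconfig_dict = {"": get_default_qconfig("fbgemm")}
    prepared = prepare_fx(float_model, qconfig_dict)  # torch 1.10-era signature
    prepared(example)                       # calibration pass
    quantized = convert_fx(prepared)

quantized(example)                          # runs as an int8 model
```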
@@ -24,7 +24,7 @@ from mobile_cv.arch.utils.quantize_utils import (
     QuantWrapper,
 )
 from mobile_cv.predictor.api import FuncInfo
-from torch.quantization.quantize_fx import prepare_fx, prepare_qat_fx, convert_fx
+from torch.ao.quantization.quantize_fx import prepare_fx, prepare_qat_fx, convert_fx

 logger = logging.getLogger(__name__)
@@ -235,9 +235,9 @@ def default_rcnn_prepare_for_quant(self, cfg):
     model = self
     torch.backends.quantized.engine = cfg.QUANTIZATION.BACKEND
     model.qconfig = (
-        torch.quantization.get_default_qat_qconfig(cfg.QUANTIZATION.BACKEND)
+        torch.ao.quantization.get_default_qat_qconfig(cfg.QUANTIZATION.BACKEND)
         if model.training
-        else torch.quantization.get_default_qconfig(cfg.QUANTIZATION.BACKEND)
+        else torch.ao.quantization.get_default_qconfig(cfg.QUANTIZATION.BACKEND)
     )
     if (
         hasattr(model, "roi_heads")
......
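The qconfig selection above returns a QAT qconfig while the model is in training mode and a post-training (static) qconfig otherwise. A minimal illustration, assuming the "fbgemm" backend in place of cfg.QUANTIZATION.BACKEND:

```python
import torch
import torch.nn as nn
from torch.ao.quantization import get_default_qat_qconfig, get_default_qconfig

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
backend = "fbgemm"  # stand-in for cfg.QUANTIZATION.BACKEND ("fbgemm" or "qnnpack")
torch.backends.quantized.engine = backend

# QAT qconfig while training, static PTQ qconfig at eval time, as in the hunk above.
model.qconfig = (
    get_default_qat_qconfig(backend)
    if model.training
    else get_default_qconfig(backend)
)
print(model.qconfig)
```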
@@ -6,13 +6,18 @@ import contextlib
 import copy
 import inspect
 import logging
+from typing import Tuple

 import torch
-import torch.quantization.quantize_fx
 from detectron2.checkpoint import DetectionCheckpointer
 from mobile_cv.arch.utils import fuse_utils
 from mobile_cv.common.misc.iter_utils import recursive_iterate

+TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
+if TORCH_VERSION >= (1, 10):
+    from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
+else:
+    from torch.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx

 logger = logging.getLogger(__name__)
@@ -148,7 +153,7 @@ def _cast_detection_model(model, device):
 def add_d2_quant_mapping(mappings):
     """HACK: Add d2 specific module mapping for eager model quantization"""
-    import torch.quantization.quantization_mappings as qm
+    import torch.ao.quantization.quantization_mappings as qm

     for k, v in mappings.items():
         if k not in qm.get_default_static_quant_module_mappings():
@@ -218,7 +223,7 @@ def default_prepare_for_quant(cfg, model):
     - QAT/PTQ can be determined by model.training.
     - Currently the input model can be changed inplace since we won't re-use the
       input model.
-    - Currently this API doesn't include the final torch.quantization.prepare(_qat)
+    - Currently this API doesn't include the final torch.ao.quantization.prepare(_qat)
       call since existing usecases don't have further steps after it.

     Args:
@@ -229,9 +234,9 @@ def default_prepare_for_quant(cfg, model):
         nn.Module: a ready model for QAT training or PTQ calibration
     """
     qconfig = (
-        torch.quantization.get_default_qat_qconfig(cfg.QUANTIZATION.BACKEND)
+        torch.ao.quantization.get_default_qat_qconfig(cfg.QUANTIZATION.BACKEND)
         if model.training
-        else torch.quantization.get_default_qconfig(cfg.QUANTIZATION.BACKEND)
+        else torch.ao.quantization.get_default_qconfig(cfg.QUANTIZATION.BACKEND)
     )

     if cfg.QUANTIZATION.EAGER_MODE:
@@ -239,14 +244,14 @@ def default_prepare_for_quant(cfg, model):
         torch.backends.quantized.engine = cfg.QUANTIZATION.BACKEND
         model.qconfig = qconfig
-        # TODO(future diff): move the torch.quantization.prepare(...) call
+        # TODO(future diff): move the torch.ao.quantization.prepare(...) call
         # here, to be consistent with the FX branch
     else:  # FX graph mode quantization
         qconfig_dict = {"": qconfig}
         if model.training:
-            model = torch.quantization.quantize_fx.prepare_qat_fx(model, qconfig_dict)
+            model = prepare_qat_fx(model, qconfig_dict)
         else:
-            model = torch.quantization.quantize_fx.prepare_fx(model, qconfig_dict)
+            model = prepare_fx(model, qconfig_dict)

     logger.info("Setup the model with qconfig:\n{}".format(qconfig))
@@ -254,7 +259,7 @@ def default_prepare_for_quant(cfg, model):
 def default_prepare_for_quant_convert(cfg, model):
-    return torch.quantization.quantize_fx.convert_fx(model)
+    return convert_fx(model)


 @mock_quantization_type
@@ -273,7 +278,7 @@ def post_training_quantize(cfg, model, data_loader):
         model = default_prepare_for_quant(cfg, model)
     if cfg.QUANTIZATION.EAGER_MODE:
-        torch.quantization.prepare(model, inplace=True)
+        torch.ao.quantization.prepare(model, inplace=True)
     logger.info("Prepared the PTQ model for calibration:\n{}".format(model))

     # Option for forcing running calibration on GPU, works only when the model supports
@@ -329,7 +334,7 @@ def setup_qat_model(cfg, model, enable_fake_quant=False, enable_observer=False):
             model = default_prepare_for_quant(cfg, model)
         # TODO(future diff): move this into prepare_for_quant to match FX branch
-        torch.quantization.prepare_qat(model, inplace=True)
+        torch.ao.quantization.prepare_qat(model, inplace=True)
     else:  # FX graph mode quantization
         if hasattr(model, "prepare_for_quant"):
             model = model.prepare_for_quant(cfg)
@@ -342,10 +347,10 @@ def setup_qat_model(cfg, model, enable_fake_quant=False, enable_observer=False):
     if not enable_fake_quant:
         logger.info("Disabling fake quant ...")
-        model.apply(torch.quantization.disable_fake_quant)
+        model.apply(torch.ao.quantization.disable_fake_quant)
     if not enable_observer:
         logger.info("Disabling observer ...")
-        model.apply(torch.quantization.disable_observer)
+        model.apply(torch.ao.quantization.disable_observer)

     # fuse_model and prepare_qat may change the state_dict of model, keep a map from the
     # orginal model to the key QAT in order to load weight from non-QAT model.
......
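Taken together, the eager-mode pieces in this file amount to: attach a qconfig, call prepare_qat, optionally toggle fake quant and observers, then convert after training. A compact sketch under the new namespace (the toy module and backend are assumptions, not d2go code):

```python
import torch
import torch.nn as nn
from torch.ao.quantization import (
    DeQuantStub,
    QuantStub,
    convert,
    disable_fake_quant,
    disable_observer,
    enable_fake_quant,
    enable_observer,
    get_default_qat_qconfig,
    prepare_qat,
)


class ToyQATModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()      # quantize the float input
        self.conv = nn.Conv2d(3, 8, 3)
        self.relu = nn.ReLU()
        self.dequant = DeQuantStub()  # back to float at the output

    def forward(self, x):
        return self.dequant(self.relu(self.conv(self.quant(x))))


model = ToyQATModel().train()
model.qconfig = get_default_qat_qconfig("fbgemm")
prepare_qat(model, inplace=True)      # insert fake-quant and observer modules

# Mirror setup_qat_model's flags: optionally start with fake quant and observers
# disabled, then turn them on once QAT should begin.
model.apply(disable_fake_quant)
model.apply(disable_observer)
# ... warm-up iterations would go here ...
model.apply(enable_observer)
model.apply(enable_fake_quant)

model(torch.randn(2, 3, 16, 16))      # one fake-quantized "training" step

# After QAT finishes, convert to a real int8 model for inference.
model.eval()
quantized = convert(model, inplace=False)
```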
@@ -13,15 +13,15 @@ from mobile_cv.arch.quantization.observer import update_stat as observer_update_
 from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.utilities import rank_zero_info
-from torch.quantization import (  # @manual
+from torch.ao.quantization import (  # @manual
     QConfig,
     QConfigDynamic,
     QuantType,
     get_default_qat_qconfig,
     get_default_qconfig,
 )
-from torch.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
-from torch.quantization.utils import get_quant_type
+from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
+from torch.ao.quantization.utils import get_quant_type

 QConfigDicts = Dict[str, Dict[str, Union[QConfig, QConfigDynamic]]]
@@ -355,12 +355,12 @@ class QuantizationAwareTraining(Callback, QuantizationMixin):
                     # Enabled by default, so the assumption for > 0 is that the
                     # user wants it disabled then enabled.
                     ModelTransform(
-                        fn=torch.quantization.disable_fake_quant,
+                        fn=torch.ao.quantization.disable_fake_quant,
                         step=0,
                         message="Disable fake quant",
                     ),
                     ModelTransform(
-                        fn=torch.quantization.enable_fake_quant,
+                        fn=torch.ao.quantization.enable_fake_quant,
                         step=start_step,
                         message="Enable fake quant to start QAT",
                     ),
@@ -371,12 +371,12 @@ class QuantizationAwareTraining(Callback, QuantizationMixin):
                 # See comment for start_step above.
                 [
                     ModelTransform(
-                        fn=torch.quantization.disable_observer,
+                        fn=torch.ao.quantization.disable_observer,
                         step=0,
                         message="Disable observer",
                     ),
                     ModelTransform(
-                        fn=torch.quantization.enable_observer,
+                        fn=torch.ao.quantization.enable_observer,
                         step=start_observer,
                         message="Start observer",
                     ),
@@ -385,7 +385,7 @@ class QuantizationAwareTraining(Callback, QuantizationMixin):
         if end_observer is not None:
             self.transforms.append(
                 ModelTransform(
-                    fn=torch.quantization.disable_observer,
+                    fn=torch.ao.quantization.disable_observer,
                     step=end_observer,
                     message="End observer",
                 )
......
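The callback above schedules the same enable/disable toggles at specific global steps. A simplified sketch of that idea with a stand-in ModelTransform tuple (the real class and its fields live in d2go; this one is hypothetical):

```python
from typing import Callable, List, NamedTuple

import torch.nn as nn
from torch.ao.quantization import (
    disable_fake_quant,
    disable_observer,
    enable_fake_quant,
    enable_observer,
)


class ModelTransform(NamedTuple):
    # Hypothetical stand-in for d2go's ModelTransform: a function applied to
    # every submodule once training reaches `step`.
    fn: Callable[[nn.Module], None]
    step: int
    message: str


transforms: List[ModelTransform] = [
    ModelTransform(fn=disable_fake_quant, step=0, message="Disable fake quant"),
    ModelTransform(fn=enable_fake_quant, step=100, message="Enable fake quant to start QAT"),
    ModelTransform(fn=disable_observer, step=0, message="Disable observer"),
    ModelTransform(fn=enable_observer, step=200, message="Start observer"),
    ModelTransform(fn=disable_observer, step=800, message="End observer"),
]


def apply_transforms(model: nn.Module, global_step: int) -> None:
    # Apply every transform whose scheduled step matches the current step.
    for t in transforms:
        if t.step == global_step:
            print(t.message)
            model.apply(t.fn)

# Usage: call apply_transforms(model, step) once per training step on a model
# that was previously prepared for QAT.
```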
@@ -630,7 +630,7 @@ class Detectron2GoRunner(BaseRunner):
                         trainer.iter
                     )
                 )
-                trainer.model.apply(torch.quantization.enable_fake_quant)
+                trainer.model.apply(torch.ao.quantization.enable_fake_quant)
                 applied["enable_fake_quant"] = True

                 if cfg.QUANTIZATION.QAT.BATCH_SIZE_FACTOR != 1.0:
@@ -661,7 +661,7 @@ class Detectron2GoRunner(BaseRunner):
                 and trainer.iter < cfg.QUANTIZATION.QAT.DISABLE_OBSERVER_ITER
             ):
                 logger.info("[QAT] enable observer, iter = {}".format(trainer.iter))
-                trainer.model.apply(torch.quantization.enable_observer)
+                trainer.model.apply(torch.ao.quantization.enable_observer)
                 applied["enable_observer"] = True

             if (
@@ -673,7 +673,7 @@ class Detectron2GoRunner(BaseRunner):
                         trainer.iter
                     )
                 )
-                trainer.model.apply(torch.quantization.disable_observer)
+                trainer.model.apply(torch.ao.quantization.disable_observer)
                 applied["disable_observer"] = True

             if (
......
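The runner hooks above drive the same toggles from the trainer iteration counter read out of the config. A condensed, hypothetical version of that schedule (the argument names and the `applied` bookkeeping dict are illustrative, not d2go's API):

```python
import torch.nn as nn
from torch.ao.quantization import disable_observer, enable_fake_quant, enable_observer


def qat_iteration_hook(
    model: nn.Module,
    iteration: int,
    start_iter: int,
    disable_observer_iter: int,
    applied: dict,
) -> None:
    # Turn fake quant on once QAT is supposed to start.
    if iteration >= start_iter and not applied.get("enable_fake_quant"):
        model.apply(enable_fake_quant)
        applied["enable_fake_quant"] = True
    # Keep observers running during the QAT window, then freeze them.
    if start_iter <= iteration < disable_observer_iter and not applied.get("enable_observer"):
        model.apply(enable_observer)
        applied["enable_observer"] = True
    if iteration >= disable_observer_iter and not applied.get("disable_observer"):
        model.apply(disable_observer)
        applied["disable_observer"] = True
```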
@@ -6,7 +6,7 @@ import torch
 from d2go.utils.testing.data_loader_helper import create_local_dataset
 from detectron2.modeling import META_ARCH_REGISTRY
 from detectron2.structures import Boxes, ImageList, Instances
-from torch.quantization.quantize_fx import prepare_qat_fx, convert_fx
+from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx


 @META_ARCH_REGISTRY.register()
@@ -54,7 +54,7 @@ class DetMetaArchForTest(torch.nn.Module):
     def prepare_for_quant(self, cfg):
         self.avgpool = prepare_qat_fx(
             self.avgpool,
-            {"": torch.quantization.get_default_qat_qconfig()},
+            {"": torch.ao.quantization.get_default_qat_qconfig()},
         )
         return self
......
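The test meta-arch above quantizes only its avgpool child with FX graph mode while the rest of the model stays float. A standalone sketch of that partial-quantization pattern (toy module; torch 1.10-era prepare_qat_fx signature):

```python
import torch
import torch.nn as nn
from torch.ao.quantization import get_default_qat_qconfig
from torch.ao.quantization.quantize_fx import convert_fx, prepare_qat_fx


class PartiallyQuantized(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Conv2d(3, 8, 3)      # stays float
        self.avgpool = nn.AdaptiveAvgPool2d(1)  # quantized in isolation

    def forward(self, x):
        return self.avgpool(self.backbone(x))


model = PartiallyQuantized().train()
qconfig_dict = {"": get_default_qat_qconfig("fbgemm")}
# Trace and prepare only the avgpool child, as the test meta-arch does.
model.avgpool = prepare_qat_fx(model.avgpool, qconfig_dict)

model(torch.randn(1, 3, 16, 16))                # a fake-quantized training step
model.eval()
model.avgpool = convert_fx(model.avgpool)       # swap in the converted avgpool
```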
@@ -20,11 +20,11 @@ from d2go.utils.testing.helper import tempdir
 from d2go.utils.testing.lightning_test_module import TestModule
 from pytorch_lightning import Trainer, seed_everything
 from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
-from torch.quantization import (  # @manual; @manual
+from torch.ao.quantization import (  # @manual; @manual
     default_dynamic_qconfig,
     get_default_qconfig,
 )
-from torch.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
+from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx


 class TestUtilities(unittest.TestCase):
......
@@ -21,7 +21,7 @@ from detectron2.modeling import META_ARCH_REGISTRY
 from detectron2.utils.events import EventStorage
 from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
 from torch import Tensor
-from torch.quantization.quantize_fx import prepare_qat_fx, convert_fx
+from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx


 class TestLightningTask(unittest.TestCase):
@@ -175,7 +175,7 @@ class TestLightningTask(unittest.TestCase):
         def prepare_for_quant(self, cfg):
             self.avgpool = prepare_qat_fx(
                 self.avgpool,
-                {"": torch.quantization.get_default_qat_qconfig()},
+                {"": torch.ao.quantization.get_default_qat_qconfig()},
                 self.custom_config_dict,
             )
             return self
......