"examples/sparse/vscode:/vscode.git/clone" did not exist on "76ce14b7b41e2f0ee568942df33971516fc95719"
Commit 97904ba4 authored by Yanghan Wang's avatar Yanghan Wang Committed by Facebook GitHub Bot
Browse files

prepare_for_quant_convert -> custom_covert_fx

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/325

`prepare_for_quant_convert` is a confusing name because it only does `convert`, there's no "prepare" in it. It's actually for fx only, because eager mode always calls `torch.quantization.convert`, there's no way to customize it. So just call this `custom_convert_fx`. The new name is also unique in fbcode, so easy to do codemod later on.

This diff simply does the renaming by biggrep + replace.

Reviewed By: jerryzh168

Differential Revision: D37676717

fbshipit-source-id: e7d05eaafddc383dd432986267c945c8ebf94df4
parent c7226783
......@@ -35,7 +35,7 @@ class CfgNode(_CfgNode):
@classmethod
def cast_from_other_class(cls, other_cfg):
"""Cast an instance of other CfgNode to D2Go's CfgNode (or its subclass)"""
new_cfg = CfgNode(other_cfg)
new_cfg = cls(other_cfg)
# copy all fields inside __dict__, this will preserve fields like __deprecated_keys__
for k, v in other_cfg.__dict__.items():
new_cfg.__dict__[k] = v
......
......@@ -76,8 +76,8 @@ def convert_quantized_model(
logger.warn("Post training quantized model has bn inside fused ops")
logger.info(f"Converting quantized model {cfg.QUANTIZATION.BACKEND}...")
if hasattr(pytorch_model, "prepare_for_quant_convert"):
pytorch_model = pytorch_model.prepare_for_quant_convert(cfg)
if hasattr(pytorch_model, "custom_convert_fx"):
pytorch_model = pytorch_model.custom_convert_fx(cfg)
else:
# TODO(T93870381): move this to a default function
if cfg.QUANTIZATION.EAGER_MODE:
......
......@@ -43,7 +43,7 @@ logger = logging.getLogger(__name__)
# of registries might be over-kill.
RCNN_PREPARE_FOR_EXPORT_REGISTRY = Registry("RCNN_PREPARE_FOR_EXPORT")
RCNN_PREPARE_FOR_QUANT_REGISTRY = Registry("RCNN_PREPARE_FOR_QUANT")
RCNN_PREPARE_FOR_QUANT_CONVERT_REGISTRY = Registry("RCNN_PREPARE_FOR_QUANT_CONVERT")
RCNN_CUSTOM_CONVERT_FX_REGISTRY = Registry("RCNN_CUSTOM_CONVERT_FX")
# Re-register D2's meta-arch in D2Go with updated APIs
......@@ -57,10 +57,8 @@ class GeneralizedRCNN(_GeneralizedRCNN):
func = RCNN_PREPARE_FOR_QUANT_REGISTRY.get(cfg.RCNN_PREPARE_FOR_QUANT)
return func(self, cfg, *args, **kwargs)
def prepare_for_quant_convert(self, cfg, *args, **kwargs):
func = RCNN_PREPARE_FOR_QUANT_CONVERT_REGISTRY.get(
cfg.RCNN_PREPARE_FOR_QUANT_CONVERT
)
def custom_convert_fx(self, cfg, *args, **kwargs):
func = RCNN_CUSTOM_CONVERT_FX_REGISTRY.get(cfg.RCNN_CUSTOM_CONVERT_FX)
return func(self, cfg, *args, **kwargs)
def _cast_model_to_device(self, device):
......@@ -329,8 +327,8 @@ def default_rcnn_prepare_for_quant(self, cfg, example_input=None):
return model
@RCNN_PREPARE_FOR_QUANT_CONVERT_REGISTRY.register()
def default_rcnn_prepare_for_quant_convert(self, cfg):
@RCNN_CUSTOM_CONVERT_FX_REGISTRY.register()
def default_rcnn_custom_convert_fx(self, cfg):
if cfg.QUANTIZATION.EAGER_MODE:
convert(self, inplace=True)
else:
......
......@@ -265,7 +265,7 @@ def default_prepare_for_quant(cfg, model, example_input=None):
return model
def default_prepare_for_quant_convert(cfg, model):
def default_custom_convert_fx(cfg, model):
return convert_fx(model)
......
......@@ -206,8 +206,8 @@ class QuantizationMixin(ABC):
Returns:
The quantized model.
"""
if hasattr(root, "prepare_for_quant_convert"):
return root.prepare_for_quant_convert()
if hasattr(root, "custom_convert_fx"):
return root.custom_convert_fx()
old_attrs = {
attr: rgetattr(root, attr) for attr in attrs if rhasattr(root, attr)
}
......
......@@ -108,7 +108,8 @@ def _add_rcnn_default_config(_C: CN) -> None:
_C.RCNN_PREPARE_FOR_EXPORT = "default_rcnn_prepare_for_export"
_C.RCNN_PREPARE_FOR_QUANT = "default_rcnn_prepare_for_quant"
_C.RCNN_PREPARE_FOR_QUANT_CONVERT = "default_rcnn_prepare_for_quant_convert"
_C.RCNN_CUSTOM_CONVERT_FX = "default_rcnn_custom_convert_fx"
_C.register_deprecated_key("RCNN_PREPARE_FOR_QUANT_CONVERT")
def get_base_runner_default_cfg(cfg: CN) -> CN:
......
......@@ -18,8 +18,8 @@ from d2go.modeling.api import build_meta_arch
from d2go.modeling.model_freezing_utils import set_requires_grad
from d2go.optimizer import build_optimizer_mapper
from d2go.quantization.modeling import (
default_custom_convert_fx,
default_prepare_for_quant,
default_prepare_for_quant_convert,
)
from d2go.runner.callbacks.quantization import maybe_prepare_for_quantization, PREPARED
from d2go.runner.default_runner import (
......@@ -483,11 +483,11 @@ class DefaultTask(pl.LightningModule):
self.model = default_prepare_for_quant(self.cfg, self.model, example_input)
return self
def prepare_for_quant_convert(self) -> pl.LightningModule:
if hasattr(self.model, "prepare_for_quant_convert"):
self.model = self.model.prepare_for_quant_convert(self.cfg)
def custom_convert_fx(self) -> pl.LightningModule:
if hasattr(self.model, "custom_convert_fx"):
self.model = self.model.custom_convert_fx(self.cfg)
else:
self.model = default_prepare_for_quant_convert(self.cfg, self.model)
self.model = default_custom_convert_fx(self.cfg, self.model)
return self
......
......@@ -67,7 +67,7 @@ class DetMetaArchForTest(torch.nn.Module):
)
return self
def prepare_for_quant_convert(self, cfg):
def custom_convert_fx(self, cfg):
self.avgpool = convert_fx(self.avgpool)
return self
......
......@@ -184,7 +184,7 @@ class TestLightningTask(unittest.TestCase):
)
return self
def prepare_for_quant_convert(self, cfg):
def custom_convert_fx(self, cfg):
self.avgpool = convert_fx(
self.avgpool, convert_custom_config=self.custom_config_dict
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment