Commit 97904ba4 authored by Yanghan Wang, committed by Facebook GitHub Bot

prepare_for_quant_convert -> custom_convert_fx

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/325

`prepare_for_quant_convert` is a confusing name: it only performs the `convert` step, so there is no "prepare" in it. It also applies only to FX mode, since eager mode always calls `torch.quantization.convert` and offers no way to customize the conversion. So rename it to `custom_convert_fx`. The new name is also unique in fbcode, which makes a later codemod easy.

This diff simply performs the rename via biggrep + replace.
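For context, a model opts into this hook by defining a `custom_convert_fx` method; the export path (see the diff below) calls it with `cfg` when present. A minimal sketch of the intended pairing with `prepare_for_quant` — the class, submodule, and qconfig choices here are illustrative assumptions, not part of this diff, and the `prepare_fx` signature shown follows PyTorch >= 1.12:

```python
import torch
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import convert_fx, prepare_fx

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.stem = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3))

    def prepare_for_quant(self, cfg):
        # Insert observers into the submodule we want quantized;
        # prepare_fx returns a GraphModule wrapping the original module.
        qconfig_dict = {"": get_default_qconfig("fbgemm")}
        example_inputs = (torch.randn(1, 3, 16, 16),)
        self.stem = prepare_fx(self.stem, qconfig_dict, example_inputs)
        return self

    def custom_convert_fx(self, cfg):
        # Convert only the FX-prepared submodule; surrounding eager-mode
        # logic (pre/post-processing) stays in float.
        self.stem = convert_fx(self.stem)
        return self
```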

Reviewed By: jerryzh168

Differential Revision: D37676717

fbshipit-source-id: e7d05eaafddc383dd432986267c945c8ebf94df4
parent c7226783
@@ -35,7 +35,7 @@ class CfgNode(_CfgNode):
     @classmethod
     def cast_from_other_class(cls, other_cfg):
         """Cast an instance of other CfgNode to D2Go's CfgNode (or its subclass)"""
-        new_cfg = CfgNode(other_cfg)
+        new_cfg = cls(other_cfg)
         # copy all fields inside __dict__, this will preserve fields like __deprecated_keys__
         for k, v in other_cfg.__dict__.items():
             new_cfg.__dict__[k] = v
@@ -76,8 +76,8 @@ def convert_quantized_model(
         logger.warn("Post training quantized model has bn inside fused ops")
     logger.info(f"Converting quantized model {cfg.QUANTIZATION.BACKEND}...")
-    if hasattr(pytorch_model, "prepare_for_quant_convert"):
-        pytorch_model = pytorch_model.prepare_for_quant_convert(cfg)
+    if hasattr(pytorch_model, "custom_convert_fx"):
+        pytorch_model = pytorch_model.custom_convert_fx(cfg)
     else:
         # TODO(T93870381): move this to a default function
         if cfg.QUANTIZATION.EAGER_MODE:
@@ -43,7 +43,7 @@ logger = logging.getLogger(__name__)
 # of registries might be over-kill.
 RCNN_PREPARE_FOR_EXPORT_REGISTRY = Registry("RCNN_PREPARE_FOR_EXPORT")
 RCNN_PREPARE_FOR_QUANT_REGISTRY = Registry("RCNN_PREPARE_FOR_QUANT")
-RCNN_PREPARE_FOR_QUANT_CONVERT_REGISTRY = Registry("RCNN_PREPARE_FOR_QUANT_CONVERT")
+RCNN_CUSTOM_CONVERT_FX_REGISTRY = Registry("RCNN_CUSTOM_CONVERT_FX")

 # Re-register D2's meta-arch in D2Go with updated APIs
@@ -57,10 +57,8 @@ class GeneralizedRCNN(_GeneralizedRCNN):
         func = RCNN_PREPARE_FOR_QUANT_REGISTRY.get(cfg.RCNN_PREPARE_FOR_QUANT)
         return func(self, cfg, *args, **kwargs)

-    def prepare_for_quant_convert(self, cfg, *args, **kwargs):
-        func = RCNN_PREPARE_FOR_QUANT_CONVERT_REGISTRY.get(
-            cfg.RCNN_PREPARE_FOR_QUANT_CONVERT
-        )
+    def custom_convert_fx(self, cfg, *args, **kwargs):
+        func = RCNN_CUSTOM_CONVERT_FX_REGISTRY.get(cfg.RCNN_CUSTOM_CONVERT_FX)
         return func(self, cfg, *args, **kwargs)

     def _cast_model_to_device(self, device):
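Downstream meta-arch customizations follow the same registry pattern under the new names. A sketch of registering and selecting an alternative convert function (the import path, function body, and submodule name are illustrative assumptions, not part of this diff):

```python
from d2go.modeling.meta_arch.rcnn import RCNN_CUSTOM_CONVERT_FX_REGISTRY
from torch.quantization.quantize_fx import convert_fx

@RCNN_CUSTOM_CONVERT_FX_REGISTRY.register()
def backbone_only_custom_convert_fx(self, cfg):
    # Hypothetical variant: convert only the FX-prepared backbone and
    # leave the detection heads in float.
    self.backbone = convert_fx(self.backbone)
    return self
```

It would then be picked up by setting `cfg.RCNN_CUSTOM_CONVERT_FX = "backbone_only_custom_convert_fx"`, mirroring how the default is wired below.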
@@ -329,8 +327,8 @@ def default_rcnn_prepare_for_quant(self, cfg, example_input=None):
     return model


-@RCNN_PREPARE_FOR_QUANT_CONVERT_REGISTRY.register()
-def default_rcnn_prepare_for_quant_convert(self, cfg):
+@RCNN_CUSTOM_CONVERT_FX_REGISTRY.register()
+def default_rcnn_custom_convert_fx(self, cfg):
     if cfg.QUANTIZATION.EAGER_MODE:
         convert(self, inplace=True)
     else:
@@ -265,7 +265,7 @@ def default_prepare_for_quant(cfg, model, example_input=None):
     return model


-def default_prepare_for_quant_convert(cfg, model):
+def default_custom_convert_fx(cfg, model):
     return convert_fx(model)
@@ -206,8 +206,8 @@ class QuantizationMixin(ABC):
         Returns:
             The quantized model.
         """
-        if hasattr(root, "prepare_for_quant_convert"):
-            return root.prepare_for_quant_convert()
+        if hasattr(root, "custom_convert_fx"):
+            return root.custom_convert_fx()
         old_attrs = {
             attr: rgetattr(root, attr) for attr in attrs if rhasattr(root, attr)
         }
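Note the asymmetry with `convert_quantized_model` above: that path calls the hook as `custom_convert_fx(cfg)`, while this Lightning mixin calls it with no arguments. A module meant to work from both call sites could make the argument optional — a sketch under that assumption (the class and submodule are hypothetical):

```python
import torch
from torch.quantization.quantize_fx import convert_fx

class DualCallerModel(torch.nn.Module):
    def custom_convert_fx(self, cfg=None):
        # Tolerates both call styles: custom_convert_fx(cfg) from the
        # d2go export path and custom_convert_fx() from QuantizationMixin.
        self.avgpool = convert_fx(self.avgpool)
        return self
```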
@@ -108,7 +108,8 @@ def _add_rcnn_default_config(_C: CN) -> None:
     _C.RCNN_PREPARE_FOR_EXPORT = "default_rcnn_prepare_for_export"
     _C.RCNN_PREPARE_FOR_QUANT = "default_rcnn_prepare_for_quant"
-    _C.RCNN_PREPARE_FOR_QUANT_CONVERT = "default_rcnn_prepare_for_quant_convert"
+    _C.RCNN_CUSTOM_CONVERT_FX = "default_rcnn_custom_convert_fx"
+    _C.register_deprecated_key("RCNN_PREPARE_FOR_QUANT_CONVERT")


 def get_base_runner_default_cfg(cfg: CN) -> CN:
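The added `register_deprecated_key` call keeps stale configs loadable: yacs, which D2Go's `CfgNode` extends, warns about and drops a deprecated key during merge instead of raising. A standalone yacs sketch of that behavior:

```python
from yacs.config import CfgNode as CN

_C = CN()
_C.RCNN_CUSTOM_CONVERT_FX = "default_rcnn_custom_convert_fx"
_C.register_deprecated_key("RCNN_PREPARE_FOR_QUANT_CONVERT")

# Merging a config that still sets the removed key logs a deprecation
# warning and ignores the key rather than raising a KeyError.
_C.merge_from_list(["RCNN_PREPARE_FOR_QUANT_CONVERT", "stale_value"])
assert "RCNN_PREPARE_FOR_QUANT_CONVERT" not in _C
```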
@@ -18,8 +18,8 @@ from d2go.modeling.api import build_meta_arch
 from d2go.modeling.model_freezing_utils import set_requires_grad
 from d2go.optimizer import build_optimizer_mapper
 from d2go.quantization.modeling import (
+    default_custom_convert_fx,
     default_prepare_for_quant,
-    default_prepare_for_quant_convert,
 )
 from d2go.runner.callbacks.quantization import maybe_prepare_for_quantization, PREPARED
 from d2go.runner.default_runner import (
@@ -483,11 +483,11 @@ class DefaultTask(pl.LightningModule):
         self.model = default_prepare_for_quant(self.cfg, self.model, example_input)
         return self

-    def prepare_for_quant_convert(self) -> pl.LightningModule:
-        if hasattr(self.model, "prepare_for_quant_convert"):
-            self.model = self.model.prepare_for_quant_convert(self.cfg)
+    def custom_convert_fx(self) -> pl.LightningModule:
+        if hasattr(self.model, "custom_convert_fx"):
+            self.model = self.model.custom_convert_fx(self.cfg)
         else:
-            self.model = default_prepare_for_quant_convert(self.cfg, self.model)
+            self.model = default_custom_convert_fx(self.cfg, self.model)
         return self
@@ -67,7 +67,7 @@ class DetMetaArchForTest(torch.nn.Module):
         )
         return self

-    def prepare_for_quant_convert(self, cfg):
+    def custom_convert_fx(self, cfg):
         self.avgpool = convert_fx(self.avgpool)
         return self
@@ -184,7 +184,7 @@ class TestLightningTask(unittest.TestCase):
         )
         return self

-    def prepare_for_quant_convert(self, cfg):
+    def custom_convert_fx(self, cfg):
         self.avgpool = convert_fx(
             self.avgpool, convert_custom_config=self.custom_config_dict
         )