Commit 6fc8f066 authored by Yanghan Wang, committed by Facebook GitHub Bot

add is_qat to custom_prepare_fx

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/332

Solve: "we have decoupled qat with `.training` in quantization, maybe we should use some flags in `cfg` instead of checking this attribute here as well"

Reviewed By: jerryzh168

Differential Revision: D37801241

fbshipit-source-id: ed9884d7b462da195ed2e07c42634acfe5beefb2
parent 75176365
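
In short, QAT vs. PTQ is now stated explicitly by the caller instead of being inferred from `model.training`. An illustrative before/after of the call sites (this sketch mirrors the diff below; `cfg`, `model_fp32`, `model`, and `data_loader` are assumed to be set up as in the surrounding d2go code):

# Before: the prepare path inferred QAT vs. PTQ from `model.training`.
# After: callers pass `is_qat` explicitly.

# QAT (setup_qat_model): prepare the fp32 model with fake-quant modules for training.
model = convert_to_fake_quant_model(cfg, model_fp32, True)

# PTQ (post_training_quantize): prepare for calibration, with a real batch as example input.
example_input = next(iter(data_loader))
model = convert_to_fake_quant_model(cfg, model, False, example_input)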
@@ -58,8 +58,8 @@ class GeneralizedRCNN(_GeneralizedRCNN):
         func = RCNN_PREPARE_FOR_QUANT_REGISTRY.get(cfg.RCNN_PREPARE_FOR_QUANT)
         return func(self, cfg, *args, **kwargs)
 
-    def custom_prepare_fx(self, cfg, example_input=None):
-        return default_rcnn_custom_prepare_fx(self, cfg, example_input)
+    def custom_prepare_fx(self, cfg, is_qat, example_input=None):
+        return default_rcnn_custom_prepare_fx(self, cfg, is_qat, example_input)
 
     def custom_convert_fx(self, cfg):
         return default_rcnn_custom_convert_fx(self, cfg)
@@ -262,8 +262,8 @@ def _get_example_rcnn_input(image_tensor_size: int):
     return [_get_batch(), _get_batch()]
 
 
-def _set_qconfig(model, cfg):
-    model.qconfig = set_backend_and_create_qconfig(cfg, is_train=model.training)
+def _set_qconfig(model, cfg, is_qat):
+    model.qconfig = set_backend_and_create_qconfig(cfg, is_train=is_qat)
     # skip quantization for point rend head
     if (
         hasattr(model, "roi_heads")
@@ -277,7 +277,7 @@ def _set_qconfig(model, cfg):
 @RCNN_PREPARE_FOR_QUANT_REGISTRY.register()
 def default_rcnn_prepare_for_quant(self, cfg):
     model = self
-    _set_qconfig(model, cfg)
+    _set_qconfig(model, cfg, model.training)
     # Modify the model for eager mode
     model = _apply_eager_mode_quant(cfg, model)
@@ -289,14 +289,14 @@ def default_rcnn_prepare_for_quant(self, cfg):
     return model
 
 
-def default_rcnn_custom_prepare_fx(self, cfg, example_input=None):
+def default_rcnn_custom_prepare_fx(self, cfg, is_qat, example_input=None):
     model = self
-    _set_qconfig(model, cfg)
+    _set_qconfig(model, cfg, is_qat)
 
     # construct example input for FX when not provided
     if example_input is None:
         assert (
-            model.training
+            is_qat
         ), "Currently only (FX mode) QAT requires user-provided `example_input`"
 
         # make sure the image size can be divided by all strides and size_divisibility
@@ -307,13 +307,13 @@ def default_rcnn_custom_prepare_fx(self, cfg, example_input=None):
         example_input = _get_example_rcnn_input(image_tensor_size)
 
-    _fx_quant_prepare(model, cfg, example_input)
+    _fx_quant_prepare(model, cfg, is_qat, example_input)
 
     return model
 
 
-def _fx_quant_prepare(self, cfg, example_input):
-    prep_fn = prepare_qat_fx if self.training else prepare_fx
+def _fx_quant_prepare(self, cfg, is_qat, example_input):
+    prep_fn = prepare_qat_fx if is_qat else prepare_fx
     qconfig = {"": self.qconfig}
     assert not isinstance(self.backbone, FPN), "FPN is not supported in FX mode"
     with EventStorage() as _:  # D2's rcnn requires EventStorage when for loss
@@ -253,7 +253,7 @@ def default_prepare_for_quant(cfg, model):
     return model
 
 
-def default_custom_prepare_fx(cfg, model, example_input=None):
+def default_custom_prepare_fx(cfg, model, is_qat, example_input=None):
     """
     Similar to default_prepare_for_quant, but for FX graph mode.
@@ -264,11 +264,15 @@ def default_custom_prepare_fx(cfg, model, example_input=None):
     """
     assert not cfg.QUANTIZATION.EAGER_MODE
-    qconfig = set_backend_and_create_qconfig(cfg, is_train=model.training)
+    qconfig = set_backend_and_create_qconfig(cfg, is_train=is_qat)
     qconfig_dict = {"": qconfig}
 
     if example_input is None:
-        example_input = model.example_input
-    if model.training:
+        raise NotImplementedError(
+            "prepare FX requires `example_input`, user should implement this for"
+            " their own MetaArch."
+        )
+
+    if is_qat:
         model = prepare_qat_fx(model, qconfig_dict, (example_input,))
     else:
         model = prepare_fx(model, qconfig_dict, (example_input,))
@@ -281,7 +285,7 @@ def default_custom_convert_fx(cfg, model):
     return convert_fx(model)
 
 
-def convert_to_fake_quant_model(cfg, model, example_input=None):
+def convert_to_fake_quant_model(cfg, model, is_qat, example_input=None):
     """
     Centralized function to convert fp32 model (D2Go's MetaArch) to fake quant model.
     """
@@ -296,20 +300,20 @@ def convert_to_fake_quant_model(cfg, model, example_input=None):
             logger.info("Using default implementation for prepare_for_quant")
             model = default_prepare_for_quant(cfg, model)
         # NOTE: eager model needs to call prepare after `prepare_for_quant`
-        if model.training:
+        if is_qat:
             torch.ao.quantization.prepare_qat(model, inplace=True)
         else:
             torch.ao.quantization.prepare(model, inplace=True)
     else:
         if hasattr(model, "custom_prepare_fx"):
-            model = model.custom_prepare_fx(cfg, example_input)
+            model = model.custom_prepare_fx(cfg, is_qat, example_input)
         # TODO: remove this branch after completely separating the eager and FX APIs
         elif hasattr(model, "prepare_for_quant"):
             model = model.prepare_for_quant(cfg, example_input)
         else:
             logger.info("Using default implementation for custom_prepare_fx")
-            model = default_custom_prepare_fx(cfg, model, example_input)
+            model = default_custom_prepare_fx(cfg, model, is_qat, example_input)
 
     return model
@@ -340,7 +344,7 @@ def post_training_quantize(cfg, model, data_loader):
         param.requires_grad = False
 
     example_input = next(iter(data_loader))
-    model = convert_to_fake_quant_model(cfg, model, example_input)
+    model = convert_to_fake_quant_model(cfg, model, False, example_input)
     logger.info("Prepared the PTQ model for calibration:\n{}".format(model))
 
     # Option for forcing running calibration on GPU, works only when the model supports
@@ -401,7 +405,7 @@ def setup_qat_model(
     model_fp32_state_dict = model_fp32.state_dict()
 
     # prepare model for qat
-    model = convert_to_fake_quant_model(cfg, model_fp32)
+    model = convert_to_fake_quant_model(cfg, model_fp32, True)
 
     # make sure the proper qconfig are used in the model
     learnable_qat.check_for_learnable_fake_quant_ops(qat_method, model)
@@ -475,12 +475,14 @@ class DefaultTask(pl.LightningModule):
             self.ema_state.load_state_dict(checkpointed_state["model_ema"])
             rank_zero_info("Loaded EMA state from checkpoint.")
 
     # TODO: remove custom_prepare_fx/custom_convert_fx from LightningModule
     def custom_prepare_fx(self) -> pl.LightningModule:
         if hasattr(self.model, "custom_prepare_fx"):
-            self.model = self.model.custom_prepare_fx(self.cfg, example_input=None)
+            self.model = self.model.custom_prepare_fx(
+                self.cfg, self.model.training, example_input=None
+            )
         else:
             self.model = default_custom_prepare_fx(
-                self.cfg, self.model, example_input=None
+                self.cfg, self.model, self.model.training, example_input=None
             )
         return self
@@ -127,3 +127,20 @@ def _log_api_usage(identifier: str):
     inside facebook's infra.
     """
     torch._C._log_api_usage_once("d2go." + identifier)
+
+
+def inplace_delegate(self, api_name, sub_module_name, *args, **kwargs):
+    """Helper function to delegate API calls to its submodule"""
+
+    sub_module = getattr(self, sub_module_name)
+    api_name = f"delegate_{api_name}"
+    if hasattr(sub_module, api_name):
+        func = getattr(sub_module, api_name)
+        # Assume the return of `func` will replace the submodule
+        setattr(self, sub_module_name, func(*args, **kwargs))
+        return self
+    else:
+        raise RuntimeError(
+            f"It seems the {sub_module_name} doesn't implement {api_name},"
+            " quantization might fail."
+        )
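
A hypothetical usage sketch of the new `inplace_delegate` helper: a wrapper module forwards `custom_prepare_fx` to its submodule by defining `delegate_custom_prepare_fx` on it. The wrapper class, the `backbone` attribute, and the import path are illustrative and not part of this diff:

import torch.nn as nn

# NOTE: import path is assumed for illustration; `inplace_delegate` is the
# helper added in the hunk above.
from d2go.utils.misc import inplace_delegate


class QuantizableBackbone(nn.Module):
    """Hypothetical submodule that knows how to prepare itself for FX quant."""

    def delegate_custom_prepare_fx(self, cfg, is_qat, example_input=None):
        # Whatever is returned here replaces the parent's `backbone` attribute.
        return self


class WrapperArch(nn.Module):
    """Hypothetical MetaArch that forwards quantization APIs to `backbone`."""

    def __init__(self):
        super().__init__()
        self.backbone = QuantizableBackbone()

    def custom_prepare_fx(self, cfg, is_qat, example_input=None):
        # Looks up `delegate_custom_prepare_fx` on `self.backbone`, calls it,
        # re-assigns the result to `self.backbone`, and returns `self`.
        return inplace_delegate(
            self, "custom_prepare_fx", "backbone", cfg, is_qat, example_input
        )

This mirrors how `convert_to_fake_quant_model` invokes `custom_prepare_fx(cfg, is_qat, example_input)` on the MetaArch in the hunks above.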