Commit 57156eea authored by Jerry Zhang, committed by Facebook GitHub Bot

Follow up fixes for example_inputs refactor

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/321

Following up on the BC-breaking change in FX graph mode quantization (https://github.com/pytorch/pytorch/pull/76496) that added example_inputs to prepare_fx and prepare_qat_fx, this change fixes the mobile-vision related call sites and exposes an extra example_inputs argument in some APIs.
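
For reference, a minimal sketch of the updated prepare_fx/convert_fx flow with the new example_inputs argument; the toy module, qconfig choice, and input shape below are illustrative assumptions, not part of this diff:

import torch
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

# prepare_fx / prepare_qat_fx now take example_inputs: a tuple of inputs used
# to trace the model with torch.fx before observers are inserted.
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
qconfig_dict = {"": get_default_qconfig("fbgemm")}
example_inputs = (torch.randn(1, 3, 224, 224),)  # hypothetical shape

prepared = prepare_fx(model, qconfig_dict, example_inputs)
prepared(*example_inputs)  # calibration pass
quantized = convert_fx(prepared)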

Reviewed By: wat3rBro

Differential Revision: D37163018

fbshipit-source-id: 9f0bb56659345d174a39b6d3bb4408caa553b88d
parent 4397dcbe
@@ -32,6 +32,8 @@ from mobile_cv.predictor.api import FuncInfo
 from torch.ao.quantization import convert
 from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
+# from torch.ao.quantization.utils import get_fqn_to_example_inputs

 logger = logging.getLogger(__name__)
@@ -64,6 +66,11 @@ class GeneralizedRCNN(_GeneralizedRCNN):
     def _cast_model_to_device(self, device):
         return _cast_detection_model(self, device)

+    @property
+    def example_input(self):
+        # TODO[quant-example-inputs]: provide correct example_input for GeneralizedRCNN
+        return torch.randn(1, 3, 224, 224)

 # Re-register D2's meta-arch in D2Go with updated APIs
 @META_ARCH_REGISTRY.register()
@@ -209,17 +216,30 @@ def _apply_eager_mode_quant(cfg, model):
     return model

-def _fx_quant_prepare(self, cfg):
+def _fx_quant_prepare(self, cfg, example_input):
     prep_fn = prepare_qat_fx if self.training else prepare_fx
     qconfig = {"": self.qconfig}
     assert not isinstance(self.backbone, FPN), "FPN is not supported in FX mode"
-    # TODO[quant-example-inputs]: Expose example_inputs as argument
-    # Note: this is used in quantization for all submodules
-    example_inputs = (torch.rand(1, 3, 3, 3),)
+    # TODO[quant-example-inputs]: set a correct example_input and uncoment the next line
+    # fqn_to_example_inputs = get_fqn_to_example_inputs(self, (example_input,))
+    fqn_to_example_inputs = {
+        "backbone": (torch.randn(1, 3, 224, 224),),
+        "proposal_generator.rpn_head.rpn_feature": (torch.randn(1, 3, 224, 224),),
+        "proposal_generator.rpn_head.rpn_regressor.cls_logits": (
+            torch.randn(1, 3, 224, 224),
+        ),
+        "proposal_generator.rpn_head.rpn_regressor.bbox_pred": (
+            torch.randn(1, 3, 224, 224),
+        ),
+        "roi_heads.box_head.roi_box_conv": (torch.randn(1, 3, 224, 224),),
+        "roi_heads.box_head.avgpool": (torch.randn(1, 3, 224, 224),),
+        "roi_heads.box_predictor.cls_score": (torch.randn(1, 3, 224, 224),),
+        "roi_heads.box_predictor.bbox_pred": (torch.randn(1, 3, 224, 224),),
+    }
     self.backbone = prep_fn(
         self.backbone,
         qconfig,
-        example_inputs,
+        fqn_to_example_inputs["backbone"],
         prepare_custom_config={
             "preserved_attributes": ["size_divisibility", "padding_constraints"],
             # keep the output of backbone quantized, to avoid
@@ -233,7 +253,7 @@ def _fx_quant_prepare(self, cfg):
     self.proposal_generator.rpn_head.rpn_feature = prep_fn(
         self.proposal_generator.rpn_head.rpn_feature,
         qconfig,
-        example_inputs,
+        fqn_to_example_inputs["proposal_generator.rpn_head.rpn_feature"],
         prepare_custom_config={
             # rpn_feature expecting quantized input, this is used to avoid redundant
             # quant
@@ -243,17 +263,17 @@ def _fx_quant_prepare(self, cfg):
     self.proposal_generator.rpn_head.rpn_regressor.cls_logits = prep_fn(
         self.proposal_generator.rpn_head.rpn_regressor.cls_logits,
         qconfig,
-        example_inputs,
+        fqn_to_example_inputs["proposal_generator.rpn_head.rpn_regressor.cls_logits"],
     )
     self.proposal_generator.rpn_head.rpn_regressor.bbox_pred = prep_fn(
         self.proposal_generator.rpn_head.rpn_regressor.bbox_pred,
         qconfig,
-        example_inputs,
+        fqn_to_example_inputs["proposal_generator.rpn_head.rpn_regressor.bbox_pred"],
     )
     self.roi_heads.box_head.roi_box_conv = prep_fn(
         self.roi_heads.box_head.roi_box_conv,
         qconfig,
-        example_inputs,
+        fqn_to_example_inputs["roi_heads.box_head.roi_box_conv"],
         prepare_custom_config={
             "output_quantized_idxs": [0],
         },
@@ -261,25 +281,25 @@ def _fx_quant_prepare(self, cfg):
     self.roi_heads.box_head.avgpool = prep_fn(
         self.roi_heads.box_head.avgpool,
         qconfig,
-        example_inputs,
+        fqn_to_example_inputs["roi_heads.box_head.avgpool"],
         prepare_custom_config={"input_quantized_idxs": [0]},
     )
     self.roi_heads.box_predictor.cls_score = prep_fn(
         self.roi_heads.box_predictor.cls_score,
         qconfig,
-        example_inputs,
+        fqn_to_example_inputs["roi_heads.box_predictor.cls_score"],
         prepare_custom_config={"input_quantized_idxs": [0]},
     )
     self.roi_heads.box_predictor.bbox_pred = prep_fn(
         self.roi_heads.box_predictor.bbox_pred,
         qconfig,
-        example_inputs,
+        fqn_to_example_inputs["roi_heads.box_predictor.bbox_pred"],
         prepare_custom_config={"input_quantized_idxs": [0]},
     )

 @RCNN_PREPARE_FOR_QUANT_REGISTRY.register()
-def default_rcnn_prepare_for_quant(self, cfg):
+def default_rcnn_prepare_for_quant(self, cfg, example_input=None):
     model = self
     model.qconfig = set_backend_and_create_qconfig(cfg, is_train=model.training)
     if (
@@ -299,7 +319,9 @@ def default_rcnn_prepare_for_quant(self, cfg):
             inplace=True,
         )
     else:
-        _fx_quant_prepare(model, cfg)
+        if example_input is None:
+            example_input = model.example_input
+        _fx_quant_prepare(model, cfg, example_input)
     return model
...
@@ -211,7 +211,8 @@ def mock_quantization_type(quant_func):
     return wrapper

-def default_prepare_for_quant(cfg, model):
+def default_prepare_for_quant(cfg, model, example_input=None):
     """
     Default implementation of preparing a model for quantization. This function will
     be called to before training if QAT is enabled, or before calibration during PTQ if
@@ -231,6 +232,9 @@ def default_prepare_for_quant(cfg, model):
     Args:
         model (nn.Module): a non-quantized model.
         cfg (CfgNode): config
+        example_input (Optional[Any]): optional example_input for model,
+            if it is not provided we'll use `model.example_input` when example_input
+            is required, Note: d2go assumes we always have a single example_input

     Return:
         nn.Module: a ready model for QAT training or PTQ calibration
@@ -249,12 +253,12 @@ def default_prepare_for_quant(cfg, model):
         # here, to be consistent with the FX branch
     else:  # FX graph mode quantization
         qconfig_dict = {"": qconfig}
-        # TODO[quant-example-inputs]: needs follow up to change the api
-        example_inputs = (torch.rand(1, 3, 3, 3),)
+        if example_input is None:
+            example_input = model.example_input
         if model.training:
-            model = prepare_qat_fx(model, qconfig_dict, example_inputs)
+            model = prepare_qat_fx(model, qconfig_dict, (example_input,))
         else:
-            model = prepare_fx(model, qconfig_dict, example_inputs)
+            model = prepare_fx(model, qconfig_dict, (example_input,))

     logger.info("Setup the model with qconfig:\n{}".format(qconfig))
@@ -265,16 +269,16 @@ def default_prepare_for_quant_convert(cfg, model):
     return convert_fx(model)

-def apply_prepare_for_quant(cfg, model):
+def apply_prepare_for_quant(cfg, model, example_input=None):
     # TODO: create a warning for the direct use of `torch.ao.quantization.get_default_qconfig`
     # or `torch.ao.quantization.get_default_qat_qconfig` without calling D2Go's high-level
     # `set_backend_and_create_qconfig` API.
     if hasattr(model, "prepare_for_quant"):
-        model = model.prepare_for_quant(cfg)
+        model = model.prepare_for_quant(cfg, example_input)
     else:
         logger.info("Using default implementation for prepare_for_quant")
-        model = default_prepare_for_quant(cfg, model)
+        model = default_prepare_for_quant(cfg, model, example_input)
     return model
@@ -288,7 +292,8 @@ def post_training_quantize(cfg, model, data_loader):
     for param in model.parameters():
         param.requires_grad = False
-    model = apply_prepare_for_quant(cfg, model)
+    example_input = next(iter(data_loader))
+    model = apply_prepare_for_quant(cfg, model, example_input)
     if cfg.QUANTIZATION.EAGER_MODE:
         torch.ao.quantization.prepare(model, inplace=True)
     logger.info("Prepared the PTQ model for calibration:\n{}".format(model))
...
@@ -21,7 +21,7 @@ from torch.ao.quantization import (  # @manual
     QuantType,
 )
 from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
-from torch.ao.quantization.utils import get_quant_type
+from torch.ao.quantization.utils import get_fqn_to_example_inputs, get_quant_type

 QConfigDicts = Dict[str, Dict[str, Union[QConfig, QConfigDynamic]]]
@@ -172,16 +172,19 @@ class QuantizationMixin(ABC):
             attr: rgetattr(root, attr) for attr in attrs if rhasattr(root, attr)
         }
         prepared = root
-        # TODO[quant-example-inputs]: expose example_inputs as argument
-        # may need a dictionary that stores a map from submodule fqn to example_inputs
-        # for submodule
-        example_inputs = (torch.rand(1, 3, 3, 3),)
         if "" in configs:
-            prepared = prep_fn(root, configs[""], example_inputs)
+            prepared = prep_fn(root, configs[""], root.example_input_array)
         else:
+            fqn_to_example_inputs = get_fqn_to_example_inputs(
+                root, root.example_input_array
+            )
             for name, config in configs.items():
                 submodule = rgetattr(root, name)
-                rsetattr(root, name, prep_fn(submodule, config, example_inputs))
+                rsetattr(
+                    root, name, prep_fn(submodule, config, fqn_to_example_inputs[name])
+                )
         for attr, value in old_attrs.items():
             rsetattr(prepared, attr, value)
         return prepared
...
@@ -476,10 +476,11 @@ class DefaultTask(pl.LightningModule):
         rank_zero_info("Loaded EMA state from checkpoint.")

     def prepare_for_quant(self) -> pl.LightningModule:
+        example_input = self.model.example_input
         if hasattr(self.model, "prepare_for_quant"):
-            self.model = self.model.prepare_for_quant(self.cfg)
+            self.model = self.model.prepare_for_quant(self.cfg, example_input)
         else:
-            self.model = default_prepare_for_quant(self.cfg, self.model)
+            self.model = default_prepare_for_quant(self.cfg, self.model, example_input)
         return self

     def prepare_for_quant_convert(self) -> pl.LightningModule:
...
@@ -26,6 +26,11 @@ class DetMetaArchForTest(torch.nn.Module):
     def device(self):
         return self.conv.weight.device

+    @property
+    def example_input(self):
+        # TODO[quant-example-inputs]: set example_input properly
+        return torch.randn(1, 3, 224, 224)
+
     def forward(self, inputs):
         if not self.training:
             return self.inference(inputs)
@@ -52,7 +57,8 @@ class DetMetaArchForTest(torch.nn.Module):
         ret = [{"instances": instance}]
         return ret

-    def prepare_for_quant(self, cfg):
+    def prepare_for_quant(self, cfg, example_input=None):
+        # TODO[quant-example-inputs]: use example_input
         example_inputs = (torch.rand(1, 3, 3, 3),)
         self.avgpool = prepare_qat_fx(
             self.avgpool,
...
@@ -228,7 +228,7 @@ class TestDefaultRunner(unittest.TestCase):
         @META_ARCH_REGISTRY.register()
         class MetaArchForTestQAT(MetaArchForTest):
-            def prepare_for_quant(self, cfg):
+            def prepare_for_quant(self, cfg, example_inputs=None):
                 """Set the qconfig to updateable observers"""
                 self.qconfig = updateable_symmetric_moving_avg_minmax_config
                 return self
...