Commit 57156eea authored by Jerry Zhang, committed by Facebook GitHub Bot

Follow up fixes for example_inputs refactor

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/321

Following up on the BC-breaking change in FX graph mode quantization (https://github.com/pytorch/pytorch/pull/76496), which added example_inputs to prepare_fx and prepare_qat_fx, this fixes the affected mobile-vision call sites and exposes an extra example_input argument in some APIs.
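
As a minimal sketch of the upstream API change (the model, qconfig, and input shape below are illustrative, not the values d2go uses):

```python
import torch
import torch.nn as nn
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()).eval()
qconfig_dict = {"": get_default_qconfig("fbgemm")}

# example_inputs is run through the model during preparation so that
# shapes/dtypes can be propagated through the traced graph.
example_inputs = (torch.randn(1, 3, 224, 224),)

# Before pytorch/pytorch#76496: prepare_fx(model, qconfig_dict)
# After: example_inputs is a required positional argument.
prepared = prepare_fx(model, qconfig_dict, example_inputs)
```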

Reviewed By: wat3rBro

Differential Revision: D37163018

fbshipit-source-id: 9f0bb56659345d174a39b6d3bb4408caa553b88d
parent 4397dcbe
@@ -32,6 +32,8 @@ from mobile_cv.predictor.api import FuncInfo
 from torch.ao.quantization import convert
 from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
+
+# from torch.ao.quantization.utils import get_fqn_to_example_inputs
 
 logger = logging.getLogger(__name__)
@@ -64,6 +66,11 @@ class GeneralizedRCNN(_GeneralizedRCNN):
     def _cast_model_to_device(self, device):
         return _cast_detection_model(self, device)
 
+    @property
+    def example_input(self):
+        # TODO[quant-example-inputs]: provide correct example_input for GeneralizedRCNN
+        return torch.randn(1, 3, 224, 224)
+
 
 # Re-register D2's meta-arch in D2Go with updated APIs
 @META_ARCH_REGISTRY.register()
@@ -209,17 +216,30 @@ def _apply_eager_mode_quant(cfg, model):
     return model
 
-def _fx_quant_prepare(self, cfg):
+def _fx_quant_prepare(self, cfg, example_input):
     prep_fn = prepare_qat_fx if self.training else prepare_fx
     qconfig = {"": self.qconfig}
     assert not isinstance(self.backbone, FPN), "FPN is not supported in FX mode"
-    # TODO[quant-example-inputs]: Expose example_inputs as argument
-    # Note: this is used in quantization for all submodules
-    example_inputs = (torch.rand(1, 3, 3, 3),)
+    # TODO[quant-example-inputs]: set a correct example_input and uncomment the next line
+    # fqn_to_example_inputs = get_fqn_to_example_inputs(self, (example_input,))
+    fqn_to_example_inputs = {
+        "backbone": (torch.randn(1, 3, 224, 224),),
+        "proposal_generator.rpn_head.rpn_feature": (torch.randn(1, 3, 224, 224),),
+        "proposal_generator.rpn_head.rpn_regressor.cls_logits": (
+            torch.randn(1, 3, 224, 224),
+        ),
+        "proposal_generator.rpn_head.rpn_regressor.bbox_pred": (
+            torch.randn(1, 3, 224, 224),
+        ),
+        "roi_heads.box_head.roi_box_conv": (torch.randn(1, 3, 224, 224),),
+        "roi_heads.box_head.avgpool": (torch.randn(1, 3, 224, 224),),
+        "roi_heads.box_predictor.cls_score": (torch.randn(1, 3, 224, 224),),
+        "roi_heads.box_predictor.bbox_pred": (torch.randn(1, 3, 224, 224),),
+    }
     self.backbone = prep_fn(
         self.backbone,
         qconfig,
-        example_inputs,
+        fqn_to_example_inputs["backbone"],
         prepare_custom_config={
             "preserved_attributes": ["size_divisibility", "padding_constraints"],
             # keep the output of backbone quantized, to avoid
@@ -233,7 +253,7 @@ def _fx_quant_prepare(self, cfg):
     self.proposal_generator.rpn_head.rpn_feature = prep_fn(
         self.proposal_generator.rpn_head.rpn_feature,
         qconfig,
-        example_inputs,
+        fqn_to_example_inputs["proposal_generator.rpn_head.rpn_feature"],
         prepare_custom_config={
             # rpn_feature expecting quantized input, this is used to avoid redundant
             # quant
@@ -243,17 +263,17 @@ def _fx_quant_prepare(self, cfg):
     self.proposal_generator.rpn_head.rpn_regressor.cls_logits = prep_fn(
         self.proposal_generator.rpn_head.rpn_regressor.cls_logits,
         qconfig,
-        example_inputs,
+        fqn_to_example_inputs["proposal_generator.rpn_head.rpn_regressor.cls_logits"],
     )
     self.proposal_generator.rpn_head.rpn_regressor.bbox_pred = prep_fn(
         self.proposal_generator.rpn_head.rpn_regressor.bbox_pred,
         qconfig,
-        example_inputs,
+        fqn_to_example_inputs["proposal_generator.rpn_head.rpn_regressor.bbox_pred"],
     )
     self.roi_heads.box_head.roi_box_conv = prep_fn(
         self.roi_heads.box_head.roi_box_conv,
         qconfig,
-        example_inputs,
+        fqn_to_example_inputs["roi_heads.box_head.roi_box_conv"],
         prepare_custom_config={
             "output_quantized_idxs": [0],
         },
@@ -261,25 +281,25 @@ def _fx_quant_prepare(self, cfg):
     self.roi_heads.box_head.avgpool = prep_fn(
         self.roi_heads.box_head.avgpool,
         qconfig,
-        example_inputs,
+        fqn_to_example_inputs["roi_heads.box_head.avgpool"],
         prepare_custom_config={"input_quantized_idxs": [0]},
     )
     self.roi_heads.box_predictor.cls_score = prep_fn(
         self.roi_heads.box_predictor.cls_score,
         qconfig,
-        example_inputs,
+        fqn_to_example_inputs["roi_heads.box_predictor.cls_score"],
         prepare_custom_config={"input_quantized_idxs": [0]},
     )
     self.roi_heads.box_predictor.bbox_pred = prep_fn(
         self.roi_heads.box_predictor.bbox_pred,
         qconfig,
-        example_inputs,
+        fqn_to_example_inputs["roi_heads.box_predictor.bbox_pred"],
         prepare_custom_config={"input_quantized_idxs": [0]},
     )
 
 
 @RCNN_PREPARE_FOR_QUANT_REGISTRY.register()
-def default_rcnn_prepare_for_quant(self, cfg):
+def default_rcnn_prepare_for_quant(self, cfg, example_input=None):
     model = self
     model.qconfig = set_backend_and_create_qconfig(cfg, is_train=model.training)
     if (
@@ -299,7 +319,9 @@ def default_rcnn_prepare_for_quant(self, cfg):
             inplace=True,
         )
     else:
-        _fx_quant_prepare(model, cfg)
+        if example_input is None:
+            example_input = model.example_input
+        _fx_quant_prepare(model, cfg, example_input)
     return model
@@ -211,7 +211,8 @@ def mock_quantization_type(quant_func):
     return wrapper
 
-def default_prepare_for_quant(cfg, model):
+def default_prepare_for_quant(cfg, model, example_input=None):
     """
     Default implementation of preparing a model for quantization. This function will
     be called before training if QAT is enabled, or before calibration during PTQ if
@@ -231,6 +232,9 @@ def default_prepare_for_quant(cfg, model):
     Args:
         model (nn.Module): a non-quantized model.
         cfg (CfgNode): config
+        example_input (Optional[Any]): optional example input for the model;
+            if not provided, `model.example_input` is used when an example
+            input is required. Note: d2go assumes a single example_input.
 
     Return:
         nn.Module: a ready model for QAT training or PTQ calibration
@@ -249,12 +253,12 @@ def default_prepare_for_quant(cfg, model):
         # here, to be consistent with the FX branch
     else:  # FX graph mode quantization
         qconfig_dict = {"": qconfig}
-        # TODO[quant-example-inputs]: needs follow up to change the api
-        example_inputs = (torch.rand(1, 3, 3, 3),)
+        if example_input is None:
+            example_input = model.example_input
         if model.training:
-            model = prepare_qat_fx(model, qconfig_dict, example_inputs)
+            model = prepare_qat_fx(model, qconfig_dict, (example_input,))
         else:
-            model = prepare_fx(model, qconfig_dict, example_inputs)
+            model = prepare_fx(model, qconfig_dict, (example_input,))
     logger.info("Setup the model with qconfig:\n{}".format(qconfig))
@@ -265,16 +269,16 @@ def default_prepare_for_quant_convert(cfg, model):
     return convert_fx(model)
 
-def apply_prepare_for_quant(cfg, model):
+def apply_prepare_for_quant(cfg, model, example_input=None):
     # TODO: create a warning for the direct use of `torch.ao.quantization.get_default_qconfig`
     # or `torch.ao.quantization.get_default_qat_qconfig` without calling D2Go's high-level
     # `set_backend_and_create_qconfig` API.
     if hasattr(model, "prepare_for_quant"):
-        model = model.prepare_for_quant(cfg)
+        model = model.prepare_for_quant(cfg, example_input)
     else:
         logger.info("Using default implementation for prepare_for_quant")
-        model = default_prepare_for_quant(cfg, model)
+        model = default_prepare_for_quant(cfg, model, example_input)
     return model
@@ -288,7 +292,8 @@ def post_training_quantize(cfg, model, data_loader):
     for param in model.parameters():
         param.requires_grad = False
 
-    model = apply_prepare_for_quant(cfg, model)
+    example_input = next(iter(data_loader))
+    model = apply_prepare_for_quant(cfg, model, example_input)
     if cfg.QUANTIZATION.EAGER_MODE:
         torch.ao.quantization.prepare(model, inplace=True)
     logger.info("Prepared the PTQ model for calibration:\n{}".format(model))
@@ -21,7 +21,7 @@ from torch.ao.quantization import (  # @manual
     QuantType,
 )
 from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx, prepare_qat_fx
-from torch.ao.quantization.utils import get_quant_type
+from torch.ao.quantization.utils import get_fqn_to_example_inputs, get_quant_type
 
 QConfigDicts = Dict[str, Dict[str, Union[QConfig, QConfigDynamic]]]
@@ -172,16 +172,19 @@ class QuantizationMixin(ABC):
             attr: rgetattr(root, attr) for attr in attrs if rhasattr(root, attr)
         }
         prepared = root
-        # TODO[quant-example-inputs]: expose example_inputs as argument
-        # may need a dictionary that stores a map from submodule fqn to example_inputs
-        # for submodule
-        example_inputs = (torch.rand(1, 3, 3, 3),)
         if "" in configs:
-            prepared = prep_fn(root, configs[""], example_inputs)
+            prepared = prep_fn(root, configs[""], root.example_input_array)
         else:
+            fqn_to_example_inputs = get_fqn_to_example_inputs(
+                root, root.example_input_array
+            )
             for name, config in configs.items():
                 submodule = rgetattr(root, name)
-                rsetattr(root, name, prep_fn(submodule, config, example_inputs))
+                rsetattr(
+                    root, name, prep_fn(submodule, config, fqn_to_example_inputs[name])
+                )
         for attr, value in old_attrs.items():
             rsetattr(prepared, attr, value)
         return prepared
@@ -476,10 +476,11 @@ class DefaultTask(pl.LightningModule):
             rank_zero_info("Loaded EMA state from checkpoint.")
 
     def prepare_for_quant(self) -> pl.LightningModule:
+        example_input = self.model.example_input
         if hasattr(self.model, "prepare_for_quant"):
-            self.model = self.model.prepare_for_quant(self.cfg)
+            self.model = self.model.prepare_for_quant(self.cfg, example_input)
         else:
-            self.model = default_prepare_for_quant(self.cfg, self.model)
+            self.model = default_prepare_for_quant(self.cfg, self.model, example_input)
         return self
 
     def prepare_for_quant_convert(self) -> pl.LightningModule:
@@ -26,6 +26,11 @@ class DetMetaArchForTest(torch.nn.Module):
     def device(self):
         return self.conv.weight.device
 
+    @property
+    def example_input(self):
+        # TODO[quant-example-inputs]: set example_input properly
+        return torch.randn(1, 3, 224, 224)
+
     def forward(self, inputs):
         if not self.training:
             return self.inference(inputs)
@@ -52,7 +57,8 @@ class DetMetaArchForTest(torch.nn.Module):
         ret = [{"instances": instance}]
         return ret
 
-    def prepare_for_quant(self, cfg):
+    def prepare_for_quant(self, cfg, example_input=None):
+        # TODO[quant-example-inputs]: use example_input
         example_inputs = (torch.rand(1, 3, 3, 3),)
         self.avgpool = prepare_qat_fx(
             self.avgpool,
@@ -228,7 +228,7 @@ class TestDefaultRunner(unittest.TestCase):
         @META_ARCH_REGISTRY.register()
         class MetaArchForTestQAT(MetaArchForTest):
-            def prepare_for_quant(self, cfg):
+            def prepare_for_quant(self, cfg, example_inputs=None):
                 """Set the qconfig to updateable observers"""
                 self.qconfig = updateable_symmetric_moving_avg_minmax_config
                 return self
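
For reference, a minimal sketch of how the hard-coded per-submodule example inputs above could eventually be derived with the get_fqn_to_example_inputs helper referenced in the TODOs (the toy module and shapes are illustrative):

```python
import torch
import torch.nn as nn
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx
from torch.ao.quantization.utils import get_fqn_to_example_inputs


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Conv2d(3, 8, 3)
        self.head = nn.Linear(8, 4)

    def forward(self, x):
        x = self.backbone(x).mean(dim=(2, 3))
        return self.head(x)


model = Toy().eval()
# Runs the model once and records, keyed by fully qualified name,
# the example inputs each submodule received.
fqn_to_example_inputs = get_fqn_to_example_inputs(
    model, (torch.randn(1, 3, 224, 224),)
)
qconfig_dict = {"": get_default_qconfig("fbgemm")}
# Each submodule can then be prepared with inputs of the right shape.
model.backbone = prepare_fx(
    model.backbone, qconfig_dict, fqn_to_example_inputs["backbone"]
)
```

This mirrors the commented-out call in _fx_quant_prepare; the 224x224 tensors in the diff remain placeholders until a correct example_input is provided for GeneralizedRCNN.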