Commit 74bc35ea authored by Yanghan Wang, committed by Facebook GitHub Bot

split default_prepare_for_quant into eager and FX mode

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/331

- remove `example_input` from `default_prepare_for_quant`, since it is now eager-mode only.
- rename `apply_prepare_for_quant` to `convert_to_fake_quant_model` to better reflect what it does (see the caller-side sketch below).
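
The caller-facing change is the rename plus the fact that each default helper now owns exactly one mode. A minimal sketch of the call pattern after this diff (names taken from the diff below, the import path is an assumption; illustrative, not the verbatim implementation):

```python
# Illustrative only: names come from this diff; the import path is an assumption.
from d2go.quantization.modeling import convert_to_fake_quant_model


def prepare_fake_quant(cfg, model, example_input=None):
    # Before this diff the equivalent call was apply_prepare_for_quant(...).
    # The renamed entry point dispatches on cfg.QUANTIZATION.EAGER_MODE:
    #   True  -> default_prepare_for_quant(cfg, model)                 (eager, no example_input)
    #   False -> default_custom_prepare_fx(cfg, model, example_input)  (FX, traced with example_input)
    return convert_to_fake_quant_model(cfg, model, example_input)
```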

Reviewed By: jerryzh168

Differential Revision: D37794085

fbshipit-source-id: e6f12098976272d979a3aac66287d9b95432dcc8
parent 6d152388
@@ -213,7 +213,7 @@ def mock_quantization_type(quant_func):
     return wrapper
 
 
-def default_prepare_for_quant(cfg, model, example_input=None):
+def default_prepare_for_quant(cfg, model):
     """
     Default implementation of preparing a model for quantization. This function will
@@ -234,26 +234,37 @@ def default_prepare_for_quant(cfg, model, example_input=None):
     Args:
         model (nn.Module): a non-quantized model.
         cfg (CfgNode): config
-        example_input (Optional[Any]): optional example_input for model,
-            if it is not provided we'll use `model.example_input` when example_input
-            is required, Note: d2go assumes we always have a single example_input
     Return:
         nn.Module: a ready model for QAT training or PTQ calibration
     """
+    assert cfg.QUANTIZATION.EAGER_MODE
     qconfig = set_backend_and_create_qconfig(cfg, is_train=model.training)
-    if cfg.QUANTIZATION.EAGER_MODE:
-        model = fuse_utils.fuse_model(
-            model,
-            is_qat=cfg.QUANTIZATION.QAT.ENABLED,
-            inplace=True,
-        )
-        model.qconfig = qconfig
-        # TODO(future diff): move the torch.ao.quantization.prepare(...) call
-        # here, to be consistent with the FX branch
-    else:  # FX graph mode quantization
+    model = fuse_utils.fuse_model(
+        model,
+        is_qat=cfg.QUANTIZATION.QAT.ENABLED,
+        inplace=True,
+    )
+    model.qconfig = qconfig
+    # TODO(future diff): move the torch.ao.quantization.prepare(...) call
+    # here, to be consistent with the FX branch
+    logger.info("Setup the model with qconfig:\n{}".format(qconfig))
+    return model
+
+
+def default_custom_prepare_fx(cfg, model, example_input=None):
+    """
+    Similar to default_prepare_for_quant, but for FX graph mode.
+
+    Args:
+        example_input (Optional[Any]): optional example_input for model,
+            if it is not provided we'll use `model.example_input` when example_input
+            is required, Note: d2go assumes we always have a single example_input
+    """
+    assert not cfg.QUANTIZATION.EAGER_MODE
+    qconfig = set_backend_and_create_qconfig(cfg, is_train=model.training)
     qconfig_dict = {"": qconfig}
     if example_input is None:
         example_input = model.example_input
@@ -263,7 +274,6 @@ def default_prepare_for_quant(cfg, model, example_input=None):
     model = prepare_fx(model, qconfig_dict, (example_input,))
     logger.info("Setup the model with qconfig:\n{}".format(qconfig))
     return model
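
For context, here is a minimal standalone sketch of the torch.ao FX flow that `default_custom_prepare_fx` wraps (plain PyTorch, not d2go code; it uses `QConfigMapping`, the newer spelling of the `{"": qconfig}` dict the diff passes). It shows why FX mode keeps the `example_input` argument: `prepare_fx` traces the model with it.

```python
import torch
from torch.ao.quantization import QConfigMapping, get_default_qconfig
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
example_input = torch.randn(1, 3, 32, 32)

qconfig_mapping = QConfigMapping().set_global(get_default_qconfig("fbgemm"))
prepared = prepare_fx(model, qconfig_mapping, (example_input,))  # trace + insert observers
prepared(example_input)                                          # calibration forward pass
int8_model = convert_fx(prepared)                                # observed graph -> int8 ops
```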
@@ -271,7 +281,10 @@ def default_custom_convert_fx(cfg, model):
     return convert_fx(model)
 
 
-def apply_prepare_for_quant(cfg, model, example_input=None):
+def convert_to_fake_quant_model(cfg, model, example_input=None):
+    """
+    Centralized function to convert fp32 model (D2Go's MetaArch) to fake quant model.
+    """
     # TODO: create a warning for the direct use of `torch.ao.quantization.get_default_qconfig`
     # or `torch.ao.quantization.get_default_qat_qconfig` without calling D2Go's high-level
     # `set_backend_and_create_qconfig` API.
@@ -281,7 +294,7 @@ def apply_prepare_for_quant(cfg, model, example_input=None):
             model = model.prepare_for_quant(cfg)
         else:
             logger.info("Using default implementation for prepare_for_quant")
-            model = default_prepare_for_quant(cfg, model, example_input)
+            model = default_prepare_for_quant(cfg, model)
         # NOTE: eager model needs to call prepare after `prepare_for_quant`
         if model.training:
             torch.ao.quantization.prepare_qat(model, inplace=True)
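
The NOTE above reflects how eager mode works in torch.ao: `prepare_for_quant` only fuses and attaches `model.qconfig`, and the observer/fake-quant insertion happens afterwards via `prepare_qat` (training) or `prepare` (eval). A minimal standalone sketch of that eager sequence (plain PyTorch, not d2go code):

```python
import torch
from torch.ao.quantization import convert, get_default_qat_qconfig, prepare_qat

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).train()
model.qconfig = get_default_qat_qconfig("qnnpack")  # analogous to model.qconfig = qconfig above
prepare_qat(model, inplace=True)                    # inserts fake-quant/observer modules
model(torch.randn(1, 3, 32, 32))                    # stand-in for QAT fine-tuning / calibration
int8_model = convert(model.eval(), inplace=False)   # fake-quant model -> int8 ops
```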
@@ -295,15 +308,16 @@ def apply_prepare_for_quant(cfg, model, example_input=None):
         elif hasattr(model, "prepare_for_quant"):
             model = model.prepare_for_quant(cfg, example_input)
         else:
-            logger.info("Using default implementation for prepare_for_quant")
-            model = default_prepare_for_quant(cfg, model, example_input)
+            logger.info("Using default implementation for custom_prepare_fx")
+            model = default_custom_prepare_fx(cfg, model, example_input)
 
     return model
 
 
 def convert_to_quantized_model(cfg, fp32_model):
     """
-    Convert fake quant model (fp32 operators) to "real" quantized model (int8 operators)
+    Centralized function to convert fake quant model (fp32 operators) to "real"
+    quantized model (int8 operators).
     """
     if cfg.QUANTIZATION.EAGER_MODE:
         int8_model = convert(fp32_model, inplace=False)
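
Both dispatch branches first check for a model-provided hook: the eager branch calls `model.prepare_for_quant(cfg)` and the FX branch calls `model.prepare_for_quant(cfg, example_input)`. A hedged sketch of a MetaArch-side override that satisfies both call sites (illustrative only; the body is elided):

```python
import torch


class MyMetaArch(torch.nn.Module):
    # Illustrative override: the signature accepts an optional example_input so the
    # same method works for the eager call (cfg only) and the FX call (cfg plus
    # example_input) shown in the hunks above.
    def prepare_for_quant(self, cfg, example_input=None):
        # model-specific fusion / qconfig assignment would go here
        return self
```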
@@ -326,7 +340,7 @@ def post_training_quantize(cfg, model, data_loader):
         param.requires_grad = False
 
     example_input = next(iter(data_loader))
-    model = apply_prepare_for_quant(cfg, model, example_input)
+    model = convert_to_fake_quant_model(cfg, model, example_input)
     logger.info("Prepared the PTQ model for calibration:\n{}".format(model))
 
     # Option for forcing running calibration on GPU, works only when the model supports
@@ -387,7 +401,7 @@ def setup_qat_model(
     model_fp32_state_dict = model_fp32.state_dict()
 
     # prepare model for qat
-    model = apply_prepare_for_quant(cfg, model_fp32)
+    model = convert_to_fake_quant_model(cfg, model_fp32)
 
     # make sure the proper qconfig are used in the model
     learnable_qat.check_for_learnable_fake_quant_ops(qat_method, model)
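
Putting the renamed entry points together, the PTQ call sites above imply the following flow. This is a schematic sketch based only on this diff: the import path is an assumption, and `calibrate` is a hypothetical placeholder for the calibration loop that `post_training_quantize` actually runs.

```python
from d2go.quantization.modeling import (  # assumed module path
    convert_to_fake_quant_model,
    convert_to_quantized_model,
)


def schematic_post_training_quantize(cfg, model, data_loader, calibrate):
    example_input = next(iter(data_loader))
    # fp32 MetaArch -> fake-quant model (eager or FX, per cfg.QUANTIZATION.EAGER_MODE)
    model = convert_to_fake_quant_model(cfg, model, example_input)
    calibrate(model, data_loader)  # forward passes so observers record activation ranges
    # fake-quant model (fp32 ops + observers) -> "real" int8 model
    return convert_to_quantized_model(cfg, model)
```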
...